349 lines
12 KiB
Python
349 lines
12 KiB
Python
|
|
import os
|
||
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import torch.optim as optim
|
||
|
|
import numpy as np
|
||
|
|
import argparse
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
from config import *
|
||
|
|
|
||
|
|
from data_preprocessing.data_loader import load_data
|
||
|
|
from models.image_models import VisionTransformer
|
||
|
|
from models.fusion_model import MultiModalFusionModel, SingleModalModel
|
||
|
|
from training.train import train_model, train_single_modal_model
|
||
|
|
from training.utils import plot_training_curves, save_results
|
||
|
|
from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
|
||
|
|
|
||
|
|
|
||
|
|
# Human-readable class names passed to every evaluation (normal / abnormal).
# Private name so it cannot collide with anything exported by `from config import *`.
_CLASS_NAMES = ('正常', '异常')


def _parse_args():
    """Build the command-line parser and return the parsed arguments.

    Every option's default value is taken from the ``config`` module.
    """
    parser = argparse.ArgumentParser(description='白血病智能筛查系统')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'],
                        default='train', help='运行模式')
    return parser.parse_args()


def _model_kwargs():
    """Return the constructor keyword arguments shared by every model built here."""
    return dict(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES,
    )


def _build_multi_modal_model(device):
    """Create a MultiModalFusionModel with the shared hyper-parameters, moved to ``device``."""
    return MultiModalFusionModel(**_model_kwargs()).to(device)


def _build_single_modal_model(device):
    """Create a SingleModalModel with the shared hyper-parameters, moved to ``device``."""
    return SingleModalModel(**_model_kwargs()).to(device)


def _make_optimizer(model, args):
    """Create an AdamW optimizer using the learning rate and weight decay from the CLI."""
    return optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


def _evaluate(model, val_loader, device):
    """Run ``evaluate_model`` on the validation loader and return its 4-tuple unchanged."""
    # Pass a fresh list on every call (as the original did) in case the callee mutates it.
    return evaluate_model(
        model=model,
        data_loader=val_loader,
        device=device,
        class_names=list(_CLASS_NAMES),
    )


def _report_multi_modal(model, val_loader, device, output_dir):
    """Evaluate the fusion model, then write its ROC curve, feature plot and metrics file."""
    metrics, labels, _, probs = _evaluate(model, val_loader, device)

    plot_roc_curve(
        labels,
        probs,
        save_path=os.path.join(output_dir, 'multi_modal_roc_curve.png'),
    )

    extract_and_visualize_features(
        model=model,
        data_loader=val_loader,
        device=device,
        save_path=os.path.join(output_dir, 'feature_visualization.png'),
    )

    save_results(metrics, os.path.join(output_dir, 'multi_modal_results.txt'))


def _plot_all_training_curves(histories, save_path):
    """Draw a 2x2 grid (train/val loss, train/val accuracy) overlaying every model.

    Args:
        histories: sequence of ``(label_prefix, history_dict)`` pairs; each
            history dict must provide 'train_losses', 'val_losses',
            'train_accs' and 'val_accs'.
        save_path: where the figure image is written.
    """
    # (history key, subplot title, y-axis label, legend-label suffix) per panel.
    panels = (
        ('train_losses', '训练损失', '损失', '训练'),
        ('val_losses', '验证损失', '损失', '验证'),
        ('train_accs', '训练准确率', '准确率', '训练'),
        ('val_accs', '验证准确率', '准确率', '验证'),
    )

    plt.figure(figsize=(12, 8))
    for position, (key, title, ylabel, suffix) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for prefix, history in histories:
            plt.plot(history[key], label=f'{prefix}{suffix}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    # Release the figure so it does not linger after the (possibly non-blocking) show().
    plt.close()


def _run_train(args, train_loader, val_loader, device):
    """'train' mode: train the fusion model, plot its curves and report on it."""
    print("创建多模态融合模型...")
    multi_modal_model = _build_multi_modal_model(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = _make_optimizer(multi_modal_model, args)

    print("开始训练多模态融合模型...")
    multi_modal_model, multi_modal_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    plot_training_curves(
        multi_modal_history['train_losses'],
        multi_modal_history['val_losses'],
        multi_modal_history['train_accs'],
        multi_modal_history['val_accs'],
        save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png'),
    )

    print("\n评估多模态融合模型...")
    _report_multi_modal(multi_modal_model, val_loader, device, args.output_dir)


def _run_compare(args, train_loader, val_loader, device):
    """'compare' mode: train DIFF-only, WNB-only and fusion models, then compare them."""
    print("创建模型进行比较...")
    # All three models are constructed before any training (as before) so that
    # weight initialisation consumes the RNG in the same order.
    diff_model = _build_single_modal_model(device)
    wnb_model = _build_single_modal_model(device)
    multi_modal_model = _build_multi_modal_model(device)

    # A single loss instance is shared by all three training runs.
    criterion = nn.CrossEntropyLoss()

    print("训练DIFF散点图单模态模型...")
    diff_optimizer = _make_optimizer(diff_model, args)
    diff_model, diff_history = train_single_modal_model(
        model=diff_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=diff_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='diff_only',
        modal_key='diff_img',
    )

    print("训练WNB散点图单模态模型...")
    wnb_optimizer = _make_optimizer(wnb_model, args)
    wnb_model, wnb_history = train_single_modal_model(
        model=wnb_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=wnb_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='wnb_only',
        modal_key='wnb_img',
    )

    print("训练多模态融合模型...")
    multi_optimizer = _make_optimizer(multi_modal_model, args)
    multi_modal_model, multi_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=multi_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    print("\n评估DIFF散点图单模态模型...")
    diff_metrics = _evaluate(diff_model, val_loader, device)[0]

    print("\n评估WNB散点图单模态模型...")
    wnb_metrics = _evaluate(wnb_model, val_loader, device)[0]

    print("\n评估多模态融合模型...")
    multi_metrics = _evaluate(multi_modal_model, val_loader, device)[0]

    compare_models(
        [diff_metrics, wnb_metrics, multi_metrics],
        ['DIFF散点图', 'WNB散点图', '多模态融合'],
        save_path=os.path.join(args.output_dir, 'model_comparison.png'),
    )

    _plot_all_training_curves(
        [('DIFF', diff_history), ('WNB', wnb_history), ('多模态', multi_history)],
        os.path.join(args.output_dir, 'all_models_training_curves.png'),
    )

    save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
    save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
    save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))


def _run_evaluate(args, val_loader, device):
    """'evaluate' mode: load the saved best fusion model and report on it."""
    print("加载预训练的多模态模型...")
    model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')

    if not os.path.exists(model_path):
        print(f"错误:找不到预训练模型 {model_path}")
        return

    multi_modal_model = _build_multi_modal_model(device)
    # map_location restores the tensors onto the current device; without it a
    # checkpoint saved on a GPU cannot be loaded when DEVICE is the CPU.
    # NOTE(review): assumes the file holds a bare state_dict (as the
    # load_state_dict call implies) and comes from a trusted source.
    multi_modal_model.load_state_dict(torch.load(model_path, map_location=device))

    print("评估多模态融合模型...")
    _report_multi_modal(multi_modal_model, val_loader, device, args.output_dir)


def main():
    """Entry point: parse CLI args, load the data, then run the selected mode.

    Modes (validated by argparse): 'train', 'compare', 'evaluate'.
    """
    args = _parse_args()

    # Make sure the output directory exists before anything writes into it.
    os.makedirs(args.output_dir, exist_ok=True)

    print("加载数据...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES,
    )
    print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")

    device = DEVICE
    print(f"使用设备: {device}")

    if args.mode == 'train':
        _run_train(args, train_loader, val_loader, device)
    elif args.mode == 'compare':
        _run_compare(args, train_loader, val_loader, device)
    elif args.mode == 'evaluate':
        _run_evaluate(args, val_loader, device)
|
||
|
|
|
||
|
|
|
||
|
|
if "__main__" == __name__:
    # Run the pipeline only when executed as a script, not when imported.
    main()
|