import os
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

from config import *
from data_preprocessing.data_loader import load_data
from models.image_models import VisionTransformer
from models.fusion_model import MultiModalFusionModel, SingleModalModel
from training.train import train_model, train_single_modal_model
from training.utils import plot_training_curves, save_results
from evaluation.evaluate import (
    evaluate_model,
    plot_roc_curve,
    extract_and_visualize_features,
    compare_models,
)

# Class names shown in evaluation reports (normal / abnormal).
CLASS_NAMES = ['正常', '异常']


def _parse_args():
    """Parse command-line arguments; defaults come from config.py constants."""
    parser = argparse.ArgumentParser(description='白血病智能筛查系统')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'],
                        default='train', help='运行模式')
    return parser.parse_args()


def _build_multi_modal_model(device):
    """Construct the multi-modal fusion model from config hyperparameters and move it to device."""
    return MultiModalFusionModel(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES,
    ).to(device)


def _build_single_modal_model(device):
    """Construct a single-modality ViT classifier from config hyperparameters and move it to device."""
    return SingleModalModel(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES,
    ).to(device)


def _make_optimizer(model, args):
    """AdamW optimizer with CLI-supplied learning rate and weight decay."""
    return optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


def _report_multi_modal(model, val_loader, device, output_dir):
    """Evaluate the multi-modal model on the validation set, then write the
    ROC curve, the t-SNE-style feature visualization, and the metrics file.

    Returns the metrics dict produced by evaluate_model.
    """
    metrics, labels, preds, probs = evaluate_model(
        model=model,
        data_loader=val_loader,
        device=device,
        class_names=CLASS_NAMES,
    )
    plot_roc_curve(
        labels,
        probs,
        save_path=os.path.join(output_dir, 'multi_modal_roc_curve.png'),
    )
    extract_and_visualize_features(
        model=model,
        data_loader=val_loader,
        device=device,
        save_path=os.path.join(output_dir, 'feature_visualization.png'),
    )
    save_results(metrics, os.path.join(output_dir, 'multi_modal_results.txt'))
    return metrics


def _plot_all_training_curves(diff_history, wnb_history, multi_history, output_dir):
    """Plot train/val loss and accuracy for all three models in one 2x2 grid."""
    # (legend prefix, history dict) for each model, in plotting order.
    histories = [('DIFF', diff_history), ('WNB', wnb_history), ('多模态', multi_history)]
    # (history key, panel title, y-axis label, legend suffix) per subplot.
    panels = [
        ('train_losses', '训练损失', '损失', '训练'),
        ('val_losses', '验证损失', '损失', '验证'),
        ('train_accs', '训练准确率', '准确率', '训练'),
        ('val_accs', '验证准确率', '准确率', '验证'),
    ]
    plt.figure(figsize=(12, 8))
    for idx, (key, title, ylabel, suffix) in enumerate(panels, start=1):
        plt.subplot(2, 2, idx)
        for prefix, history in histories:
            plt.plot(history[key], label=prefix + suffix)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'all_models_training_curves.png'))
    plt.show()


def _run_train(args, train_loader, val_loader, device):
    """Train the multi-modal fusion model, plot its curves, evaluate and save results."""
    print("创建多模态融合模型...")
    model = _build_multi_modal_model(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = _make_optimizer(model, args)

    print("开始训练多模态融合模型...")
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    plot_training_curves(
        history['train_losses'],
        history['val_losses'],
        history['train_accs'],
        history['val_accs'],
        save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png'),
    )

    print("\n评估多模态融合模型...")
    _report_multi_modal(model, val_loader, device, args.output_dir)


def _run_compare(args, train_loader, val_loader, device):
    """Train DIFF-only, WNB-only, and fused models, then compare their metrics."""
    print("创建模型进行比较...")
    diff_model = _build_single_modal_model(device)
    wnb_model = _build_single_modal_model(device)
    multi_modal_model = _build_multi_modal_model(device)

    criterion = nn.CrossEntropyLoss()

    print("训练DIFF散点图单模态模型...")
    diff_model, diff_history = train_single_modal_model(
        model=diff_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=_make_optimizer(diff_model, args),
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='diff_only',
        modal_key='diff_img',
    )

    print("训练WNB散点图单模态模型...")
    wnb_model, wnb_history = train_single_modal_model(
        model=wnb_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=_make_optimizer(wnb_model, args),
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='wnb_only',
        modal_key='wnb_img',
    )

    print("训练多模态融合模型...")
    multi_modal_model, multi_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=_make_optimizer(multi_modal_model, args),
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    print("\n评估DIFF散点图单模态模型...")
    diff_metrics, _, _, _ = evaluate_model(
        model=diff_model,
        data_loader=val_loader,
        device=device,
        class_names=CLASS_NAMES,
    )

    print("\n评估WNB散点图单模态模型...")
    wnb_metrics, _, _, _ = evaluate_model(
        model=wnb_model,
        data_loader=val_loader,
        device=device,
        class_names=CLASS_NAMES,
    )

    print("\n评估多模态融合模型...")
    multi_metrics, _, _, _ = evaluate_model(
        model=multi_modal_model,
        data_loader=val_loader,
        device=device,
        class_names=CLASS_NAMES,
    )

    # Side-by-side bar comparison of the three models' metrics.
    compare_models(
        [diff_metrics, wnb_metrics, multi_metrics],
        ['DIFF散点图', 'WNB散点图', '多模态融合'],
        save_path=os.path.join(args.output_dir, 'model_comparison.png'),
    )

    _plot_all_training_curves(diff_history, wnb_history, multi_history, args.output_dir)

    save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
    save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
    save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))


def _run_evaluate(args, val_loader, device):
    """Load the best saved multi-modal checkpoint and run the evaluation report."""
    print("加载预训练的多模态模型...")
    model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')
    if not os.path.exists(model_path):
        print(f"错误:找不到预训练模型 {model_path}")
        return

    model = _build_multi_modal_model(device)
    # map_location is required so a checkpoint saved on GPU still loads on a
    # CPU-only machine (and lands on whichever device config selected).
    model.load_state_dict(torch.load(model_path, map_location=device))

    print("评估多模态融合模型...")
    _report_multi_modal(model, val_loader, device, args.output_dir)


def main():
    """Entry point: load data, then dispatch to train / compare / evaluate mode."""
    args = _parse_args()

    # Make sure the output directory exists before any artifact is written.
    os.makedirs(args.output_dir, exist_ok=True)

    print("加载数据...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES,
    )
    print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")

    device = DEVICE
    print(f"使用设备: {device}")

    if args.mode == 'train':
        _run_train(args, train_loader, val_loader, device)
    elif args.mode == 'compare':
        _run_compare(args, train_loader, val_loader, device)
    elif args.mode == 'evaluate':
        _run_evaluate(args, val_loader, device)


if __name__ == "__main__":
    main()