349 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			349 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						|
import torch
 | 
						|
import torch.nn as nn
 | 
						|
import torch.optim as optim
 | 
						|
import numpy as np
 | 
						|
import argparse
 | 
						|
import matplotlib.pyplot as plt
 | 
						|
from config import *
 | 
						|
 | 
						|
from data_preprocessing.data_loader import load_data
 | 
						|
from models.image_models import VisionTransformer
 | 
						|
from models.fusion_model import MultiModalFusionModel, SingleModalModel
 | 
						|
from training.train import train_model, train_single_modal_model
 | 
						|
from training.utils import plot_training_curves, save_results
 | 
						|
from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
 | 
						|
 | 
						|
 | 
						|
# Class display names handed to evaluate_model for every evaluated model.
_CLASS_NAMES = ('正常', '异常')


def _parse_args():
    """Build the command-line parser and return the parsed arguments.

    Every default is taken from a constant provided by the star import
    from ``config`` (DATA_ROOT, OUTPUT_DIR, BATCH_SIZE, NUM_EPOCHS,
    LEARNING_RATE, WEIGHT_DECAY).
    """
    parser = argparse.ArgumentParser(description='白血病智能筛查系统')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'], default='train', help='运行模式')
    return parser.parse_args()


def _build_model(model_cls, device):
    """Instantiate ``model_cls`` with the shared hyper-parameters and move it to ``device``.

    Both model classes used by this script are constructed with the same
    keyword arguments, so the call lives in one place.
    """
    return model_cls(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES,
    ).to(device)


def _make_optimizer(model, args):
    """Return an AdamW optimizer over ``model``'s parameters using the CLI lr / weight decay."""
    return optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)


def _evaluate(model, val_loader, device):
    """Run ``evaluate_model`` on the validation loader and return its 4-tuple unchanged."""
    return evaluate_model(
        model=model,
        data_loader=val_loader,
        device=device,
        # Fresh list per call, exactly as the call sites originally passed a new literal.
        class_names=list(_CLASS_NAMES),
    )


def _report_multi_modal(model, val_loader, device, output_dir):
    """Evaluate the fusion model, then write its ROC curve, feature plot and metrics file."""
    metrics, labels, _preds, probs = _evaluate(model, val_loader, device)

    plot_roc_curve(
        labels,
        probs,
        save_path=os.path.join(output_dir, 'multi_modal_roc_curve.png'),
    )
    extract_and_visualize_features(
        model=model,
        data_loader=val_loader,
        device=device,
        save_path=os.path.join(output_dir, 'feature_visualization.png'),
    )
    save_results(metrics, os.path.join(output_dir, 'multi_modal_results.txt'))


def _plot_history_comparison(histories, output_dir):
    """Draw a 2x2 grid comparing loss/accuracy histories and save it to ``output_dir``.

    ``histories`` is a sequence of ``(legend_prefix, history)`` pairs, where each
    history is indexed with the keys 'train_losses', 'val_losses', 'train_accs'
    and 'val_accs'.
    """
    # (subplot index, history key, legend suffix, title, y-axis label)
    panels = (
        (1, 'train_losses', '训练', '训练损失', '损失'),
        (2, 'val_losses', '验证', '验证损失', '损失'),
        (3, 'train_accs', '训练', '训练准确率', '准确率'),
        (4, 'val_accs', '验证', '验证准确率', '准确率'),
    )

    plt.figure(figsize=(12, 8))
    for index, key, suffix, title, ylabel in panels:
        plt.subplot(2, 2, index)
        for prefix, history in histories:
            plt.plot(history[key], label=prefix + suffix)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'all_models_training_curves.png'))
    plt.show()


def _run_train(args, train_loader, val_loader, device):
    """'train' mode: train the fusion model, plot its curves, then evaluate and report."""
    print("创建多模态融合模型...")
    multi_modal_model = _build_model(MultiModalFusionModel, device)

    criterion = nn.CrossEntropyLoss()
    optimizer = _make_optimizer(multi_modal_model, args)

    print("开始训练多模态融合模型...")
    multi_modal_model, multi_modal_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    plot_training_curves(
        multi_modal_history['train_losses'],
        multi_modal_history['val_losses'],
        multi_modal_history['train_accs'],
        multi_modal_history['val_accs'],
        save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png'),
    )

    print("\n评估多模态融合模型...")
    _report_multi_modal(multi_modal_model, val_loader, device, args.output_dir)


def _run_compare(args, train_loader, val_loader, device):
    """'compare' mode: train two single-modal models and the fusion model, then compare them."""
    print("创建模型进行比较...")

    # Construction order (DIFF, WNB, fusion) is kept so that any RNG-driven
    # weight initialisation happens in the same sequence as before.
    diff_model = _build_model(SingleModalModel, device)
    wnb_model = _build_model(SingleModalModel, device)
    multi_modal_model = _build_model(MultiModalFusionModel, device)

    criterion = nn.CrossEntropyLoss()

    print("训练DIFF散点图单模态模型...")
    diff_model, diff_history = train_single_modal_model(
        model=diff_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=_make_optimizer(diff_model, args),
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='diff_only',
        modal_key='diff_img',
    )

    print("训练WNB散点图单模态模型...")
    wnb_model, wnb_history = train_single_modal_model(
        model=wnb_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=_make_optimizer(wnb_model, args),
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='wnb_only',
        modal_key='wnb_img',
    )

    print("训练多模态融合模型...")
    multi_modal_model, multi_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=_make_optimizer(multi_modal_model, args),
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    # Only the metrics (first element of the returned tuple) are needed here.
    print("\n评估DIFF散点图单模态模型...")
    diff_metrics = _evaluate(diff_model, val_loader, device)[0]

    print("\n评估WNB散点图单模态模型...")
    wnb_metrics = _evaluate(wnb_model, val_loader, device)[0]

    print("\n评估多模态融合模型...")
    multi_metrics = _evaluate(multi_modal_model, val_loader, device)[0]

    compare_models(
        [diff_metrics, wnb_metrics, multi_metrics],
        ['DIFF散点图', 'WNB散点图', '多模态融合'],
        save_path=os.path.join(args.output_dir, 'model_comparison.png'),
    )

    _plot_history_comparison(
        [('DIFF', diff_history), ('WNB', wnb_history), ('多模态', multi_history)],
        args.output_dir,
    )

    save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
    save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
    save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))


def _run_evaluate(args, val_loader, device):
    """'evaluate' mode: load the saved fusion-model weights and report on the validation set.

    Prints an error and returns without evaluating when the checkpoint file
    ``multi_modal_best.pth`` is not present under MODEL_SAVE_DIR.
    """
    print("加载预训练的多模态模型...")
    model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')

    if not os.path.exists(model_path):
        print(f"错误:找不到预训练模型 {model_path}")
        return

    multi_modal_model = _build_model(MultiModalFusionModel, device)

    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU can still be loaded on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=device)
    multi_modal_model.load_state_dict(state_dict)

    print("评估多模态融合模型...")
    _report_multi_modal(multi_modal_model, val_loader, device, args.output_dir)


def main():
    """Command-line entry point.

    Parses the arguments, creates the output directory, loads the train and
    validation loaders, and then dispatches on ``--mode``:

    * ``train``    - train and evaluate the multi-modal fusion model;
    * ``compare``  - train and compare two single-modal models and the fusion model;
    * ``evaluate`` - evaluate a previously saved fusion model.
    """
    args = _parse_args()

    # Make sure the output directory exists before anything is written to it.
    os.makedirs(args.output_dir, exist_ok=True)

    print("加载数据...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES,
    )
    print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")

    device = DEVICE
    print(f"使用设备: {device}")

    if args.mode == 'train':
        _run_train(args, train_loader, val_loader, device)
    elif args.mode == 'compare':
        _run_compare(args, train_loader, val_loader, device)
    elif args.mode == 'evaluate':
        _run_evaluate(args, val_loader, device)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()