"""leukemia/main.py — entry point for the leukemia intelligent screening system."""
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
import matplotlib.pyplot as plt
from config import *
from data_preprocessing.data_loader import load_data
from models.image_models import VisionTransformer
from models.fusion_model import MultiModalFusionModel, SingleModalModel
from training.train import train_model, train_single_modal_model
from training.utils import plot_training_curves, save_results
from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
def main():
    """Entry point: parse CLI arguments, load the data, then dispatch on mode.

    Modes:
        train    -- train the multi-modal fusion model and run the full evaluation.
        compare  -- train DIFF-only, WNB-only and fusion models, then compare them.
        evaluate -- load the best saved fusion checkpoint and evaluate it.
    """
    args = _parse_args()

    # Make sure the output directory exists before anything tries to write to it.
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the train/validation DataLoaders (splitting and sampling handled by load_data).
    print("加载数据...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES,
    )
    print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")

    device = DEVICE
    print(f"使用设备: {device}")

    if args.mode == 'train':
        _run_train(args, train_loader, val_loader, device)
    elif args.mode == 'compare':
        _run_compare(args, train_loader, val_loader, device)
    elif args.mode == 'evaluate':
        _run_evaluate(args, val_loader, device)


def _parse_args():
    """Parse command-line arguments; defaults come from the config module."""
    parser = argparse.ArgumentParser(description='白血病智能筛查系统')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'],
                        default='train', help='运行模式')
    return parser.parse_args()


def _vit_kwargs():
    """Shared ViT hyper-parameters used by every model variant (from config)."""
    return dict(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES,
    )


def _build_fusion_model(device):
    """Construct a MultiModalFusionModel with the shared hyper-parameters on *device*."""
    return MultiModalFusionModel(**_vit_kwargs()).to(device)


def _build_single_model(device):
    """Construct a SingleModalModel with the shared hyper-parameters on *device*."""
    return SingleModalModel(**_vit_kwargs()).to(device)


def _evaluate_and_visualize(model, val_loader, device, output_dir):
    """Evaluate *model*, plot the ROC curve, visualize features, save metrics.

    Returns the metrics dict produced by evaluate_model.
    """
    metrics, labels, preds, probs = evaluate_model(
        model=model,
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常'],
    )
    plot_roc_curve(
        labels,
        probs,
        save_path=os.path.join(output_dir, 'multi_modal_roc_curve.png'),
    )
    extract_and_visualize_features(
        model=model,
        data_loader=val_loader,
        device=device,
        save_path=os.path.join(output_dir, 'feature_visualization.png'),
    )
    save_results(metrics, os.path.join(output_dir, 'multi_modal_results.txt'))
    return metrics


def _run_train(args, train_loader, val_loader, device):
    """Train the multi-modal fusion model, then run the full evaluation suite."""
    print("创建多模态融合模型...")
    model = _build_fusion_model(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    print("开始训练多模态融合模型...")
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    plot_training_curves(
        history['train_losses'],
        history['val_losses'],
        history['train_accs'],
        history['val_accs'],
        save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png'),
    )

    print("\n评估多模态融合模型...")
    _evaluate_and_visualize(model, val_loader, device, args.output_dir)


def _run_compare(args, train_loader, val_loader, device):
    """Train and evaluate both single-modal models and the fusion model, then compare."""
    print("创建模型进行比较...")
    diff_model = _build_single_model(device)
    wnb_model = _build_single_model(device)
    multi_model = _build_fusion_model(device)

    criterion = nn.CrossEntropyLoss()

    # Each model gets its own optimizer over its own parameters.
    print("训练DIFF散点图单模态模型...")
    diff_optimizer = optim.AdamW(diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    diff_model, diff_history = train_single_modal_model(
        model=diff_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=diff_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='diff_only',
        modal_key='diff_img',
    )

    print("训练WNB散点图单模态模型...")
    wnb_optimizer = optim.AdamW(wnb_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    wnb_model, wnb_history = train_single_modal_model(
        model=wnb_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=wnb_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='wnb_only',
        modal_key='wnb_img',
    )

    print("训练多模态融合模型...")
    multi_optimizer = optim.AdamW(multi_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    multi_model, multi_history = train_model(
        model=multi_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=multi_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal',
    )

    # Evaluate all three on the same validation loader for a fair comparison.
    print("\n评估DIFF散点图单模态模型...")
    diff_metrics, _, _, _ = evaluate_model(
        model=diff_model,
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常'],
    )
    print("\n评估WNB散点图单模态模型...")
    wnb_metrics, _, _, _ = evaluate_model(
        model=wnb_model,
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常'],
    )
    print("\n评估多模态融合模型...")
    multi_metrics, _, _, _ = evaluate_model(
        model=multi_model,
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常'],
    )

    compare_models(
        [diff_metrics, wnb_metrics, multi_metrics],
        ['DIFF散点图', 'WNB散点图', '多模态融合'],
        save_path=os.path.join(args.output_dir, 'model_comparison.png'),
    )

    _plot_comparison_curves(
        diff_history,
        wnb_history,
        multi_history,
        os.path.join(args.output_dir, 'all_models_training_curves.png'),
    )

    save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
    save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
    save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))


def _plot_comparison_curves(diff_history, wnb_history, multi_history, save_path):
    """Plot train/val loss and accuracy of all three models in a 2x2 grid."""
    # (history key, subplot title, y-axis label, legend suffix)
    panels = [
        ('train_losses', '训练损失', '损失', '训练'),
        ('val_losses', '验证损失', '损失', '验证'),
        ('train_accs', '训练准确率', '准确率', '训练'),
        ('val_accs', '验证准确率', '准确率', '验证'),
    ]
    plt.figure(figsize=(12, 8))
    for idx, (key, title, ylabel, suffix) in enumerate(panels, start=1):
        plt.subplot(2, 2, idx)
        plt.plot(diff_history[key], label=f'DIFF{suffix}')
        plt.plot(wnb_history[key], label=f'WNB{suffix}')
        plt.plot(multi_history[key], label=f'多模态{suffix}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()


def _run_evaluate(args, val_loader, device):
    """Load the best saved fusion checkpoint and run the full evaluation suite."""
    print("加载预训练的多模态模型...")
    model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')
    if not os.path.exists(model_path):
        print(f"错误:找不到预训练模型 {model_path}")
        return

    model = _build_fusion_model(device)
    # map_location keeps CPU-only hosts from failing on GPU-trained checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))

    print("评估多模态融合模型...")
    _evaluate_and_visualize(model, val_loader, device, args.output_dir)


if __name__ == "__main__":
    main()