# leukemia/main.py

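"""Entry point for the leukemia intelligent screening system.

Three run modes are supported via --mode:
  * train    -- train the multi-modal fusion model, then evaluate it
  * compare  -- train DIFF-only, WNB-only, and fusion models and compare them
  * evaluate -- load a saved fusion checkpoint and evaluate it
"""
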
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
import matplotlib.pyplot as plt
from config import *
from data_preprocessing.data_loader import load_data
from models.image_models import VisionTransformer
from models.fusion_model import MultiModalFusionModel, SingleModalModel
from training.train import train_model, train_single_modal_model
from training.utils import plot_training_curves, save_results
from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
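
# Hyperparameters and paths used unqualified below (IMG_SIZE, HIDDEN_DIM,
# NUM_LAYERS, NUM_HEADS, DROPOUT, NUM_CLASSES, DEVICE, MODEL_SAVE_DIR,
# NUM_WORKERS, TRAIN_RATIO, etc.) are provided by the star import from config.
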
def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Leukemia intelligent screening system')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='data root directory')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='output directory')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='batch size')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='number of training epochs')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='learning rate')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='weight decay')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'], default='train', help='run mode')
    args = parser.parse_args()

    # Make sure the output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    # Load the data
    print("Loading data...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES
    )
    print(f"Data loaded. Training batches: {len(train_loader)}, validation batches: {len(val_loader)}")

    # Set the device
    device = DEVICE
    print(f"Using device: {device}")

    if args.mode == 'train':
        # Build the multi-modal fusion model
        print("Creating multi-modal fusion model...")
        multi_modal_model = MultiModalFusionModel(
            img_size=IMG_SIZE,
            patch_size=16,
            in_channels=3,
            embed_dim=HIDDEN_DIM,
            depth=NUM_LAYERS,
            num_heads=NUM_HEADS,
            dropout=DROPOUT,
            num_classes=NUM_CLASSES
        ).to(device)

        # Define the loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        # Train the multi-modal model
        print("Training multi-modal fusion model...")
        multi_modal_model, multi_modal_history = train_model(
            model=multi_modal_model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            num_epochs=args.epochs,
            save_dir=MODEL_SAVE_DIR,
            model_name='multi_modal'
        )

        # Plot the training history
        plot_training_curves(
            multi_modal_history['train_losses'],
            multi_modal_history['val_losses'],
            multi_modal_history['train_accs'],
            multi_modal_history['val_accs'],
            save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png')
        )

        # Evaluate the multi-modal model
        print("\nEvaluating multi-modal fusion model...")
        multi_modal_metrics, labels, preds, probs = evaluate_model(
            model=multi_modal_model,
            data_loader=val_loader,
            device=device,
            class_names=['Normal', 'Abnormal']
        )

        # Plot the ROC curve
        plot_roc_curve(
            labels,
            probs,
            save_path=os.path.join(args.output_dir, 'multi_modal_roc_curve.png')
        )

        # Extract and visualize features
        extract_and_visualize_features(
            model=multi_modal_model,
            data_loader=val_loader,
            device=device,
            save_path=os.path.join(args.output_dir, 'feature_visualization.png')
        )

        # Save the metrics
        save_results(
            multi_modal_metrics,
            os.path.join(args.output_dir, 'multi_modal_results.txt')
        )

    elif args.mode == 'compare':
        print("Creating models for comparison...")

        # Single-modal model for DIFF scatter plots
        diff_model = SingleModalModel(
            img_size=IMG_SIZE,
            patch_size=16,
            in_channels=3,
            embed_dim=HIDDEN_DIM,
            depth=NUM_LAYERS,
            num_heads=NUM_HEADS,
            dropout=DROPOUT,
            num_classes=NUM_CLASSES
        ).to(device)

        # Single-modal model for WNB scatter plots
        wnb_model = SingleModalModel(
            img_size=IMG_SIZE,
            patch_size=16,
            in_channels=3,
            embed_dim=HIDDEN_DIM,
            depth=NUM_LAYERS,
            num_heads=NUM_HEADS,
            dropout=DROPOUT,
            num_classes=NUM_CLASSES
        ).to(device)

        # Multi-modal fusion model
        multi_modal_model = MultiModalFusionModel(
            img_size=IMG_SIZE,
            patch_size=16,
            in_channels=3,
            embed_dim=HIDDEN_DIM,
            depth=NUM_LAYERS,
            num_heads=NUM_HEADS,
            dropout=DROPOUT,
            num_classes=NUM_CLASSES
        ).to(device)

        # Define the loss function
        criterion = nn.CrossEntropyLoss()

        # Train the DIFF single-modal model
        print("Training DIFF scatter-plot single-modal model...")
        diff_optimizer = optim.AdamW(diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        diff_model, diff_history = train_single_modal_model(
            model=diff_model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=diff_optimizer,
            device=device,
            num_epochs=args.epochs,
            save_dir=MODEL_SAVE_DIR,
            model_name='diff_only',
            modal_key='diff_img'
        )

        # Train the WNB single-modal model
        print("Training WNB scatter-plot single-modal model...")
        wnb_optimizer = optim.AdamW(wnb_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        wnb_model, wnb_history = train_single_modal_model(
            model=wnb_model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=wnb_optimizer,
            device=device,
            num_epochs=args.epochs,
            save_dir=MODEL_SAVE_DIR,
            model_name='wnb_only',
            modal_key='wnb_img'
        )

        # Train the multi-modal fusion model
        print("Training multi-modal fusion model...")
        multi_optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        multi_modal_model, multi_history = train_model(
            model=multi_modal_model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=multi_optimizer,
            device=device,
            num_epochs=args.epochs,
            save_dir=MODEL_SAVE_DIR,
            model_name='multi_modal'
        )

        # Evaluate and compare the models
        print("\nEvaluating DIFF scatter-plot single-modal model...")
        diff_metrics, _, _, _ = evaluate_model(
            model=diff_model,
            data_loader=val_loader,
            device=device,
            class_names=['Normal', 'Abnormal']
        )

        print("\nEvaluating WNB scatter-plot single-modal model...")
        wnb_metrics, _, _, _ = evaluate_model(
            model=wnb_model,
            data_loader=val_loader,
            device=device,
            class_names=['Normal', 'Abnormal']
        )

        print("\nEvaluating multi-modal fusion model...")
        multi_metrics, _, _, _ = evaluate_model(
            model=multi_modal_model,
            data_loader=val_loader,
            device=device,
            class_names=['Normal', 'Abnormal']
        )

        # Compare the performance of the three models
        compare_models(
            [diff_metrics, wnb_metrics, multi_metrics],
            ['DIFF scatter plot', 'WNB scatter plot', 'Multi-modal fusion'],
            save_path=os.path.join(args.output_dir, 'model_comparison.png')
        )

        # Plot training curves for all three models in a 2x2 grid:
        # one panel per metric, one line per model
        histories = {'DIFF': diff_history, 'WNB': wnb_history, 'Fusion': multi_history}
        panels = [
            ('train_losses', 'Training loss', 'Loss'),
            ('val_losses', 'Validation loss', 'Loss'),
            ('train_accs', 'Training accuracy', 'Accuracy'),
            ('val_accs', 'Validation accuracy', 'Accuracy'),
        ]
        plt.figure(figsize=(12, 8))
        for i, (key, title, ylabel) in enumerate(panels, start=1):
            plt.subplot(2, 2, i)
            for name, history in histories.items():
                plt.plot(history[key], label=name)
            plt.title(title)
            plt.xlabel('Epoch')
            plt.ylabel(ylabel)
            plt.legend()
            plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(args.output_dir, 'all_models_training_curves.png'))
        plt.show()

        # Save the metrics for each model
        save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
        save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
        save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))

    elif args.mode == 'evaluate':
        # Load the pretrained multi-modal model
        print("Loading pretrained multi-modal model...")
        model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')
        if not os.path.exists(model_path):
            print(f"Error: pretrained model not found at {model_path}")
            return

        multi_modal_model = MultiModalFusionModel(
            img_size=IMG_SIZE,
            patch_size=16,
            in_channels=3,
            embed_dim=HIDDEN_DIM,
            depth=NUM_LAYERS,
            num_heads=NUM_HEADS,
            dropout=DROPOUT,
            num_classes=NUM_CLASSES
        ).to(device)
        # map_location lets a checkpoint saved on GPU load on CPU-only machines
        multi_modal_model.load_state_dict(torch.load(model_path, map_location=device))

        # Evaluate the model
        print("Evaluating multi-modal fusion model...")
        multi_modal_metrics, labels, preds, probs = evaluate_model(
            model=multi_modal_model,
            data_loader=val_loader,
            device=device,
            class_names=['Normal', 'Abnormal']
        )

        # Plot the ROC curve
        plot_roc_curve(
            labels,
            probs,
            save_path=os.path.join(args.output_dir, 'multi_modal_roc_curve.png')
        )

        # Extract and visualize features
        extract_and_visualize_features(
            model=multi_modal_model,
            data_loader=val_loader,
            device=device,
            save_path=os.path.join(args.output_dir, 'feature_visualization.png')
        )

        # Save the metrics
        save_results(
            multi_modal_metrics,
            os.path.join(args.output_dir, 'multi_modal_results.txt')
        )


if __name__ == "__main__":
    main()
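
# Example invocations (illustrative values only; flag defaults come from
# config.py, and the script is assumed to be run from the leukemia/ directory):
#   python main.py --mode train --epochs 50 --lr 1e-4
#   python main.py --mode compare --batch_size 32
#   python main.py --mode evaluate --output_dir ./outputs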