"""Central configuration for the DIFF/WNB scattergram classification project.

Every value here is a plain module-level constant imported by the data,
model and training code. Importing this module creates ``MODEL_SAVE_DIR``
on disk (relative to the current working directory).
"""
import os

import torch

# NOTE(review): the .gitignore added in the same change lists `_pycache_/`
# (single underscores). Python writes bytecode to `__pycache__/`, so that entry
# never matches -- which is why *.pyc files were committed alongside this
# module. The ignore entry should read `__pycache__/`.

# --- Paths -----------------------------------------------------------------
DATA_ROOT = "./data"  # data root; contains the class sub-folders 000, 001, ...
OUTPUT_DIR = "./output"
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")

# Make sure the checkpoint directory exists before anything tries to save.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

# --- Data ------------------------------------------------------------------
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
NUM_WORKERS = 4
TRAIN_RATIO = 0.8
# Derived rather than hard-coded so the two ratios cannot drift apart.
# round() removes the binary floating-point residue (1 - 0.8 != 0.2 exactly).
VAL_RATIO = round(1.0 - TRAIN_RATIO, 10)

# --- Sample-count control ----------------------------------------------------
MAX_SAMPLES_PER_CLASS = 1000  # upper bound on samples read per class
NORMAL_CLASS = "000"  # folder name of the normal class
# Folder names of the abnormal classes.
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"]

# --- Model -----------------------------------------------------------------
HIDDEN_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 6
DROPOUT = 0.1
NUM_CLASSES = 2  # binary task: normal vs. abnormal

# --- Training ----------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
"""Data loading, evaluation helpers and the command-line entry point.

This span of the change set covers three modules, kept as separate sections:

* ``data_preprocessing/data_loader.py`` -- ``LeukemiaDataset``, ``load_data``
* ``evaluation/evaluate.py`` -- ``evaluate_model``, ``plot_roc_curve``,
  ``extract_and_visualize_features``, ``compare_models``
* ``main.py`` -- ``main``
"""
import argparse
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.manifold import TSNE
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from config import (
    ABNORMAL_CLASSES,
    BATCH_SIZE,
    DATA_ROOT,
    DEVICE,
    DROPOUT,
    HIDDEN_DIM,
    IMG_SIZE,
    LEARNING_RATE,
    MAX_SAMPLES_PER_CLASS,
    MODEL_SAVE_DIR,
    NORMAL_CLASS,
    NUM_CLASSES,
    NUM_EPOCHS,
    NUM_HEADS,
    NUM_LAYERS,
    NUM_WORKERS,
    OUTPUT_DIR,
    TRAIN_RATIO,
    WEIGHT_DECAY,
)
from models.image_models import VisionTransformer
from models.fusion_model import MultiModalFusionModel, SingleModalModel
from training.train import train_model, train_single_modal_model
from training.utils import compute_metrics, plot_confusion_matrix, plot_training_curves, save_results


# ===========================================================================
# data_preprocessing/data_loader.py
# ===========================================================================

class LeukemiaDataset(Dataset):
    """Dataset of paired DIFF / WNB scattergram images with binary labels."""

    def __init__(self, sample_paths, labels, transform=None):
        """Store the sample index.

        Args:
            sample_paths: list of ``(diff_path, wnb_path)`` tuples.
            labels: list of labels aligned with ``sample_paths``
                (0: normal, 1: abnormal).
            transform: optional callable applied to each loaded PIL image.
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_paths)

    @staticmethod
    def _load_rgb(path):
        """Open *path*, return an RGB copy and close the file handle."""
        # convert() loads the pixel data and returns a new image, so the
        # result stays valid after the underlying file is closed.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return a dict with keys ``id``, ``diff_img``, ``wnb_img``, ``label``."""
        diff_path, wnb_path = self.sample_paths[idx]

        # The sample id is the DIFF file name without its 'Diff.png' suffix.
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')

        diff_img = self._load_rgb(diff_path)
        wnb_img = self._load_rgb(wnb_path)

        if self.transform:
            diff_img = self.transform(diff_img)
            wnb_img = self.transform(wnb_img)

        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }


def _collect_pairs(folder, limit, rng):
    """Return at most *limit* ``(diff_path, wnb_path)`` pairs found in *folder*.

    Only DIFF images that have a matching ``<id>Wnb.png`` next to them are
    kept. Pairs are filtered *before* the limit is applied so that missing
    WNB files do not shrink the sample below what is actually available.
    """
    pairs = []
    # sorted(): glob order is filesystem dependent; sorting makes the seeded
    # subsampling below reproducible.
    for diff_file in sorted(glob.glob(os.path.join(folder, "*Diff.png"))):
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(folder, f"{sample_id}Wnb.png")
        if os.path.exists(wnb_file):
            pairs.append((diff_file, wnb_file))

    if len(pairs) > limit:
        rng.shuffle(pairs)
        pairs = pairs[:limit]
    return pairs


def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
              max_samples_per_class=500, normal_class="000", abnormal_classes=None,
              seed=42):
    """Collect image pairs, split them and build the data loaders.

    Args:
        data_root: data root directory containing the class sub-folders.
        img_size: images are resized to ``img_size x img_size``.
        batch_size: batch size for both loaders.
        num_workers: number of DataLoader worker processes.
        train_ratio: fraction of samples used for training.
        max_samples_per_class: maximum number of normal samples; the same
            total budget is divided evenly across the abnormal folders.
        normal_class: folder name of the normal class.
        abnormal_classes: list of abnormal folder names; defaults to
            ``["001", "010", "011", "100", "101", "110", "111"]``.
        seed: seed for subsampling and for the train/validation split.
            Pass ``None`` for a non-deterministic selection.

    Returns:
        ``(train_loader, val_loader)``.

    Raises:
        ValueError: if ``abnormal_classes`` is empty or no image pair is found.
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]
    if not abnormal_classes:
        raise ValueError("abnormal_classes must contain at least one folder name")

    # ImageNet mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    rng = random.Random(seed)

    normal_samples = _collect_pairs(
        os.path.join(data_root, normal_class), max_samples_per_class, rng
    )

    # The abnormal budget equals the normal budget, shared evenly between
    # the abnormal folders.
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)
    abnormal_samples = []
    for abnormal_class in abnormal_classes:
        abnormal_samples.extend(
            _collect_pairs(os.path.join(data_root, abnormal_class),
                           samples_per_abnormal_class, rng)
        )

    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)

    print(f"收集到的正常样本数: {len(normal_samples)}")
    print(f"收集到的异常样本数: {len(abnormal_samples)}")
    print(f"总样本数: {len(all_samples)}")

    if not all_samples:
        raise ValueError(f"no Diff/Wnb image pairs found under {data_root!r}")

    # Stratified split keeps the class ratio equal in both subsets.
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1 - train_ratio),
        stratify=all_labels,
        random_state=seed
    )

    train_dataset = LeukemiaDataset(
        [all_samples[i] for i in train_indices],
        [all_labels[i] for i in train_indices],
        transform
    )
    val_dataset = LeukemiaDataset(
        [all_samples[i] for i in val_indices],
        [all_labels[i] for i in val_indices],
        transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    print(f"训练集样本数: {len(train_dataset)}, 验证集样本数: {len(val_dataset)}")

    return train_loader, val_loader


# ===========================================================================
# evaluation/evaluate.py
# ===========================================================================

def evaluate_model(model, data_loader, device, class_names=None, modal_key=None):
    """Run *model* over *data_loader* and report classification metrics.

    Args:
        model: the network to evaluate.
        data_loader: loader yielding dicts with ``diff_img``, ``wnb_img``
            and ``label`` entries.
        device: device the batches are moved to.
        class_names: confusion-matrix labels; defaults to ``['正常', '异常']``.
        modal_key: ``None`` for a two-input fusion model (called as
            ``model(diff, wnb)``); ``'diff_img'`` or ``'wnb_img'`` for a
            single-input model, which is then called as
            ``model(batch[modal_key])``.

    Returns:
        ``(metrics, all_labels, all_preds, all_probs)`` where ``all_probs``
        is a NumPy array of softmax probabilities.
    """
    model.eval()

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            labels = batch['label'].to(device)

            if modal_key is None:
                outputs = model(batch['diff_img'].to(device), batch['wnb_img'].to(device))
            else:
                # Single-modality models take exactly one image tensor.
                outputs = model(batch[modal_key].to(device))

            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    metrics = compute_metrics(all_labels, all_preds, all_probs)

    if class_names is None:
        class_names = ['正常', '异常']
    plot_confusion_matrix(all_labels, all_preds, class_names)

    print(f"准确率: {metrics['accuracy']:.4f}")
    print(f"精确率: {metrics['precision']:.4f}")
    print(f"召回率: {metrics['recall']:.4f}")
    print(f"F1值: {metrics['f1']:.4f}")

    return metrics, all_labels, all_preds, np.array(all_probs)


def plot_roc_curve(all_labels, all_probs, save_path=None):
    """Plot the ROC curve of the positive (abnormal) class and return its AUC.

    Args:
        all_labels: ground-truth labels.
        all_probs: per-class probabilities, one row per sample; column 1 is
            used as the positive-class score.
        save_path: optional path the figure is written to.
    """
    # Accept plain lists as well as arrays.
    pos_probs = np.asarray(all_probs)[:, 1]

    fpr, tpr, _ = roc_curve(all_labels, pos_probs)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率')
    plt.ylabel('真阳性率')
    plt.title('受试者工作特征(ROC)曲线')
    plt.legend(loc='lower right')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()

    return roc_auc


def _tsne_embed(features):
    """Project *features* (one row per sample) to 2-D with t-SNE."""
    features = np.asarray(features)
    n_samples = len(features)
    if n_samples < 2:
        raise ValueError(f"t-SNE needs at least 2 samples, got {n_samples}")
    # scikit-learn requires perplexity < n_samples; 30 is its default.
    perplexity = min(30.0, float(n_samples - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    return tsne.fit_transform(features)


def _scatter_by_label(position, embedding, labels, title):
    """Draw one labelled 2-D scatter panel at *position* of a 1x3 grid."""
    plt.subplot(1, 3, position)
    for label in np.unique(labels):
        mask = labels == label
        plt.scatter(embedding[mask, 0], embedding[mask, 1],
                    label=f"类别 {label}", alpha=0.7)
    plt.title(title)
    plt.xlabel('t-SNE 特征 1')
    plt.ylabel('t-SNE 特征 2')
    plt.legend()
    plt.grid(True)


def extract_and_visualize_features(model, data_loader, device, save_path=None):
    """Extract per-modality features and visualise them with t-SNE.

    Three panels are drawn: DIFF features, WNB features, and both
    concatenated. *model* must provide ``extract_features(diff, wnb)``
    returning a dict with ``'diff'`` and ``'wnb'`` tensors.
    """
    model.eval()

    features_dict = {
        'diff': [],
        'wnb': []
    }
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)

            batch_features = model.extract_features(diff_imgs, wnb_imgs)

            for modality, collected in features_dict.items():
                collected.extend(batch_features[modality].cpu().numpy())
            all_labels.extend(batch['label'].cpu().numpy())

    all_labels = np.array(all_labels)

    plt.figure(figsize=(15, 5))

    for position, (modality, features) in enumerate(features_dict.items(), start=1):
        _scatter_by_label(position, _tsne_embed(features), all_labels,
                          f'{modality} 特征 t-SNE 可视化')

    # Third panel: both modalities concatenated along the feature axis.
    combined_features = np.concatenate([features_dict['diff'], features_dict['wnb']], axis=1)
    _scatter_by_label(3, _tsne_embed(combined_features), all_labels,
                      '融合特征 t-SNE 可视化')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    plt.show()
    plt.close()


def compare_models(models_results, model_names, save_path=None):
    """Draw a grouped bar chart comparing accuracy/precision/recall/F1.

    Args:
        models_results: list of metric dicts, one per model.
        model_names: display names aligned with ``models_results``.
        save_path: optional path the figure is written to.

    Raises:
        ValueError: if no results are given or the two lists differ in length.
    """
    if not models_results:
        raise ValueError("models_results must not be empty")
    if len(models_results) != len(model_names):
        raise ValueError("models_results and model_names must have the same length")

    metrics = ['accuracy', 'precision', 'recall', 'f1']
    values = np.array([[result[metric] for metric in metrics] for result in models_results])

    plt.figure(figsize=(10, 6))
    x = np.arange(len(metrics))
    width = 0.8 / len(models_results)

    for i, (name, vals) in enumerate(zip(model_names, values)):
        plt.bar(x + i * width, vals, width, label=name)

    plt.xlabel('评估指标')
    plt.ylabel('分数')
    plt.title('不同模型性能比较')
    # Centre each tick under its group of bars.
    plt.xticks(x + width * (len(models_results) - 1) / 2, metrics)
    plt.ylim(0, 1.0)
    plt.legend()
    plt.grid(True, axis='y')

    if save_path:
        plt.savefig(save_path)

    plt.show()
    plt.close()


# ===========================================================================
# main.py
# ===========================================================================

_CLASS_NAMES = ['正常', '异常']


def _build_model(model_cls, device):
    """Instantiate *model_cls* with the configured hyper-parameters on *device*."""
    return model_cls(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES
    ).to(device)


def _report_fusion_model(model, val_loader, device, output_dir):
    """Evaluate the fusion model and write ROC, t-SNE and metric files."""
    metrics, labels, _, probs = evaluate_model(
        model=model,
        data_loader=val_loader,
        device=device,
        class_names=_CLASS_NAMES
    )
    plot_roc_curve(
        labels,
        probs,
        save_path=os.path.join(output_dir, 'multi_modal_roc_curve.png')
    )
    extract_and_visualize_features(
        model=model,
        data_loader=val_loader,
        device=device,
        save_path=os.path.join(output_dir, 'feature_visualization.png')
    )
    save_results(metrics, os.path.join(output_dir, 'multi_modal_results.txt'))


def _run_train(args, train_loader, val_loader, device):
    """Train the fusion model, plot its history and evaluate it."""
    print("创建多模态融合模型...")
    model = _build_model(MultiModalFusionModel, device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    print("开始训练多模态融合模型...")
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal'
    )

    plot_training_curves(
        history['train_losses'],
        history['val_losses'],
        history['train_accs'],
        history['val_accs'],
        save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png')
    )

    print("\n评估多模态融合模型...")
    _report_fusion_model(model, val_loader, device, args.output_dir)


def _plot_history_comparison(histories, save_path):
    """Plot loss/accuracy curves of several models in a 2x2 grid.

    *histories* is a list of ``(label_prefix, history_dict)`` pairs.
    """
    # (history key, panel title, y-axis label, legend suffix)
    panels = [
        ('train_losses', '训练损失', '损失', '训练'),
        ('val_losses', '验证损失', '损失', '验证'),
        ('train_accs', '训练准确率', '准确率', '训练'),
        ('val_accs', '验证准确率', '准确率', '验证'),
    ]

    plt.figure(figsize=(12, 8))
    for position, (key, title, ylabel, suffix) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for prefix, history in histories:
            plt.plot(history[key], label=f'{prefix}{suffix}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    plt.close()


def _run_compare(args, train_loader, val_loader, device):
    """Train DIFF-only, WNB-only and fusion models and compare them."""
    print("创建模型进行比较...")

    diff_model = _build_model(SingleModalModel, device)
    wnb_model = _build_model(SingleModalModel, device)
    multi_modal_model = _build_model(MultiModalFusionModel, device)

    criterion = nn.CrossEntropyLoss()

    print("训练DIFF散点图单模态模型...")
    diff_optimizer = optim.AdamW(diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    diff_model, diff_history = train_single_modal_model(
        model=diff_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=diff_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='diff_only',
        modal_key='diff_img'
    )

    print("训练WNB散点图单模态模型...")
    wnb_optimizer = optim.AdamW(wnb_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    wnb_model, wnb_history = train_single_modal_model(
        model=wnb_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=wnb_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='wnb_only',
        modal_key='wnb_img'
    )

    print("训练多模态融合模型...")
    multi_optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    multi_modal_model, multi_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=multi_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal'
    )

    # Single-modality models take one image, so tell evaluate_model which
    # batch entry to feed them.
    print("\n评估DIFF散点图单模态模型...")
    diff_metrics, _, _, _ = evaluate_model(
        model=diff_model,
        data_loader=val_loader,
        device=device,
        class_names=_CLASS_NAMES,
        modal_key='diff_img'
    )

    print("\n评估WNB散点图单模态模型...")
    wnb_metrics, _, _, _ = evaluate_model(
        model=wnb_model,
        data_loader=val_loader,
        device=device,
        class_names=_CLASS_NAMES,
        modal_key='wnb_img'
    )

    print("\n评估多模态融合模型...")
    multi_metrics, _, _, _ = evaluate_model(
        model=multi_modal_model,
        data_loader=val_loader,
        device=device,
        class_names=_CLASS_NAMES
    )

    compare_models(
        [diff_metrics, wnb_metrics, multi_metrics],
        ['DIFF散点图', 'WNB散点图', '多模态融合'],
        save_path=os.path.join(args.output_dir, 'model_comparison.png')
    )

    _plot_history_comparison(
        [('DIFF', diff_history), ('WNB', wnb_history), ('多模态', multi_history)],
        os.path.join(args.output_dir, 'all_models_training_curves.png')
    )

    save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
    save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
    save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))


def _run_evaluate(args, train_loader, val_loader, device):
    """Load the saved best fusion checkpoint and evaluate it."""
    print("加载预训练的多模态模型...")
    model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')

    if not os.path.exists(model_path):
        print(f"错误:找不到预训练模型 {model_path}")
        return

    model = _build_model(MultiModalFusionModel, device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    model.load_state_dict(torch.load(model_path, map_location=device))

    print("评估多模态融合模型...")
    _report_fusion_model(model, val_loader, device, args.output_dir)


def main():
    """Parse the command line, load the data and run the selected mode."""
    parser = argparse.ArgumentParser(description='白血病智能筛查系统')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'], default='train', help='运行模式')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    print("加载数据...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES
    )
    print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")

    device = DEVICE
    print(f"使用设备: {device}")

    # argparse `choices` guarantees args.mode is one of these keys.
    runners = {
        'train': _run_train,
        'compare': _run_compare,
        'evaluate': _run_evaluate,
    }
    runners[args.mode](args, train_loader, val_loader, device)


if __name__ == "__main__":
    main()
"""Transformer image encoder and the classifiers built on top of it.

This span of the change set covers two modules, kept as separate sections:

* ``models/fusion_model.py`` -- ``MultiModalFusionModel``, ``SingleModalModel``
* ``models/image_models.py`` -- ``PatchEmbedding``, ``VisionTransformer``

``VisionTransformer`` is defined further down in this block, so the
cross-module import that ``fusion_model.py`` used for it is not repeated.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer


# ===========================================================================
# models/fusion_model.py
# ===========================================================================

class MultiModalFusionModel(nn.Module):
    """Fuse DIFF and WNB scattergram features for classification.

    Each modality has its own ``VisionTransformer`` encoder; the two feature
    vectors are concatenated, projected back to ``embed_dim`` and classified
    by a linear head.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        encoder_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )

        # Separate encoders: the two scattergram types do not share weights.
        self.diff_encoder = VisionTransformer(**encoder_kwargs)
        self.wnb_encoder = VisionTransformer(**encoder_kwargs)

        # Fusion layer: [B, 2*E] -> [B, E]
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """Classify a batch of image pairs.

        Args:
            diff_img: DIFF scattergrams, ``[B, C, H, W]``.
            wnb_img: WNB scattergrams, ``[B, C, H, W]``.

        Returns:
            Classification logits, ``[B, num_classes]``.
        """
        diff_features = self.diff_encoder(diff_img)  # [B, E]
        wnb_features = self.wnb_encoder(wnb_img)  # [B, E]

        combined_features = torch.cat([diff_features, wnb_features], dim=1)  # [B, 2*E]
        fused_features = self.fusion(combined_features)  # [B, E]

        return self.classifier(fused_features)  # [B, num_classes]

    def extract_features(self, diff_img, wnb_img):
        """Return the un-fused encoder outputs as ``{'diff': ..., 'wnb': ...}``."""
        return {
            'diff': self.diff_encoder(diff_img),
            'wnb': self.wnb_encoder(wnb_img)
        }


class SingleModalModel(nn.Module):
    """Single-image baseline (one encoder plus linear head) for ablations."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()

        # Image feature extractor
        self.encoder = VisionTransformer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout
        )

        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """Classify a batch of images.

        Args:
            img: input images, ``[B, C, H, W]``.

        Returns:
            Classification logits, ``[B, num_classes]``.
        """
        features = self.encoder(img)  # [B, E]
        return self.classifier(features)  # [B, num_classes]


# ===========================================================================
# models/image_models.py
# ===========================================================================

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2

        # A strided convolution embeds every patch in a single pass.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, N, E]`` patch embeddings."""
        x = self.proj(x)  # [B, E, H/P, W/P]
        x = x.flatten(2)  # [B, E, (H/P)*(W/P)]
        return x.transpose(1, 2)  # [B, (H/P)*(W/P), E]


class VisionTransformer(nn.Module):
    """Transformer encoder that returns the CLS-token feature of an image.

    Args:
        img_size: expected (square) input resolution.
        patch_size: side length of each square patch.
        in_channels: number of input channels.
        embed_dim: token / output feature dimension.
        depth: number of encoder layers.
        num_heads: attention heads per layer.
        dropout: dropout rate inside the encoder layers.
        apply_final_norm: when True, pass the CLS feature through the
            ``norm`` LayerNorm before returning it. Defaults to False, which
            leaves the layer unused and so keeps checkpoints trained without
            it numerically unchanged.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1,
                 apply_final_norm=False):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches

        # Learned position embedding (one slot per patch plus the CLS token)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer encoder
        encoder_layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers=depth)

        # Final layer norm. Always created so state_dict keys stay stable;
        # only applied when apply_final_norm is True.
        self.norm = nn.LayerNorm(embed_dim)
        self.apply_final_norm = apply_final_norm

        # Initialisation
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Encode ``[B, C, H, W]`` images into ``[B, E]`` feature vectors.

        Raises:
            ValueError: if the input resolution yields a patch count other
                than the one the position embedding was built for.
        """
        batch_size = x.shape[0]
        height, width = x.shape[-2:]

        x = self.patch_embed(x)  # [B, N, E]

        # Fail with a clear message instead of an opaque broadcast error
        # at the position-embedding addition below.
        if x.shape[1] != self.n_patches:
            raise ValueError(
                f"input of size {height}x{width} produces {x.shape[1]} patches, "
                f"but the model was built for {self.n_patches}"
            )

        # Prepend the CLS token
        cls_token = self.cls_token.expand(batch_size, -1, -1)  # [B, 1, E]
        x = torch.cat([cls_token, x], dim=1)  # [B, N+1, E]

        x = x + self.pos_embed
        x = self.transformer(x)

        # The CLS token summarises the whole image.
        x = x[:, 0]  # [B, E]

        if self.apply_final_norm:
            x = self.norm(x)

        return x
Binary files /dev/null and b/training/__pycache__/train.cpython-38.pyc differ diff --git a/training/__pycache__/utils.cpython-312.pyc b/training/__pycache__/utils.cpython-312.pyc new file mode 100644 index 0000000..ea34919 Binary files /dev/null and b/training/__pycache__/utils.cpython-312.pyc differ diff --git a/training/__pycache__/utils.cpython-38.pyc b/training/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..6d36b74 Binary files /dev/null and b/training/__pycache__/utils.cpython-38.pyc differ diff --git a/training/train.py b/training/train.py new file mode 100644 index 0000000..64f7802 --- /dev/null +++ b/training/train.py @@ -0,0 +1,312 @@ +import torch +import torch.nn as nn +import torch.optim as optim +import numpy as np +import os +from tqdm import tqdm +import time +from training.utils import AverageMeter, EarlyStopping, plot_training_curves, compute_metrics + +def train_epoch(model, train_loader, criterion, optimizer, device): + """训练一个epoch""" + model.train() + losses = AverageMeter() + acc = AverageMeter() + + # 进度条 + pbar = tqdm(train_loader, desc='训练') + + for batch in pbar: + # 获取数据和标签 + diff_imgs = batch['diff_img'].to(device) + wnb_imgs = batch['wnb_img'].to(device) + labels = batch['label'].to(device) + + # 前向传播 + outputs = model(diff_imgs, wnb_imgs) + + # 计算损失 + loss = criterion(outputs, labels) + + # 反向传播和优化 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # 计算准确率 + _, preds = torch.max(outputs, 1) + batch_acc = (preds == labels).float().mean() + + # 更新统计 + losses.update(loss.item(), labels.size(0)) + acc.update(batch_acc.item(), labels.size(0)) + + # 更新进度条 + pbar.set_postfix({ + 'loss': losses.avg, + 'acc': acc.avg + }) + + return losses.avg, acc.avg + + +def validate(model, val_loader, criterion, device): + """验证模型""" + model.eval() + losses = AverageMeter() + acc = AverageMeter() + + all_labels = [] + all_preds = [] + all_probs = [] + + with torch.no_grad(): + for batch in tqdm(val_loader, desc='验证'): + # 
获取数据和标签 + diff_imgs = batch['diff_img'].to(device) + wnb_imgs = batch['wnb_img'].to(device) + labels = batch['label'].to(device) + + # 前向传播 + outputs = model(diff_imgs, wnb_imgs) + + # 计算损失 + loss = criterion(outputs, labels) + + # 计算准确率 + probs = torch.softmax(outputs, dim=1) + _, preds = torch.max(outputs, 1) + batch_acc = (preds == labels).float().mean() + + # 更新统计 + losses.update(loss.item(), labels.size(0)) + acc.update(batch_acc.item(), labels.size(0)) + + # 保存预测结果用于计算指标 + all_labels.extend(labels.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + all_probs.extend(probs.cpu().numpy()) + + # 计算其他评估指标 + metrics = compute_metrics(all_labels, all_preds, all_probs) + metrics['loss'] = losses.avg + metrics['accuracy'] = acc.avg + + return losses.avg, acc.avg, metrics + + +def train_model(model, train_loader, val_loader, criterion, optimizer, device, + num_epochs=50, save_dir='./models', model_name='model'): + """训练模型""" + # 保存最佳模型的路径 + if not os.path.exists(save_dir): + os.makedirs(save_dir) + best_model_path = os.path.join(save_dir, f'{model_name}_best.pth') + + # 初始化早停 + early_stopping = EarlyStopping(patience=10, path=best_model_path) + + # 跟踪训练历史 + train_losses = [] + val_losses = [] + train_accs = [] + val_accs = [] + best_val_acc = 0.0 + + # 开始训练 + start_time = time.time() + + for epoch in range(num_epochs): + print(f'\nEpoch {epoch+1}/{num_epochs}') + print('-' * 20) + + # 训练 + train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device) + train_losses.append(train_loss) + train_accs.append(train_acc) + + # 验证 + val_loss, val_acc, val_metrics = validate(model, val_loader, criterion, device) + val_losses.append(val_loss) + val_accs.append(val_acc) + + # 打印当前epoch的结果 + print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}') + print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}') + print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}') + + # 检查是否为最佳验证准确率 + if val_acc > 
best_val_acc: + best_val_acc = val_acc + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'val_acc': val_acc, + 'val_metrics': val_metrics + }, os.path.join(save_dir, f'{model_name}_best_acc.pth')) + print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}') + + # 早停检查 + early_stopping(val_loss, model) + if early_stopping.early_stop: + print(f"早停! 在第 {epoch+1} 个Epoch停止训练") + break + + # 计算总训练时间 + total_time = time.time() - start_time + print(f'训练完成! 总用时: {total_time/60:.2f} 分钟') + + # 加载最佳模型 + model.load_state_dict(torch.load(best_model_path)) + + return model, { + 'train_losses': train_losses, + 'val_losses': val_losses, + 'train_accs': train_accs, + 'val_accs': val_accs, + 'best_val_acc': best_val_acc, + 'total_time': total_time + } + + +def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device, + num_epochs=50, save_dir='./models', model_name='single_modal', modal_key='diff_img'): + """训练单模态模型""" + # 保存最佳模型的路径 + if not os.path.exists(save_dir): + os.makedirs(save_dir) + best_model_path = os.path.join(save_dir, f'{model_name}_best.pth') + + # 初始化早停 + early_stopping = EarlyStopping(patience=10, path=best_model_path) + + # 跟踪训练历史 + train_losses = [] + val_losses = [] + train_accs = [] + val_accs = [] + best_val_acc = 0.0 + + # 开始训练 + start_time = time.time() + + for epoch in range(num_epochs): + print(f'\nEpoch {epoch+1}/{num_epochs}') + print('-' * 20) + + # 训练 + model.train() + train_loss = AverageMeter() + train_acc = AverageMeter() + + pbar = tqdm(train_loader, desc='训练') + for batch in pbar: + # 获取数据和标签 + imgs = batch[modal_key].to(device) + labels = batch['label'].to(device) + + # 前向传播 + outputs = model(imgs) + + # 计算损失 + loss = criterion(outputs, labels) + + # 反向传播和优化 + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # 计算准确率 + _, preds = torch.max(outputs, 1) + batch_acc = (preds == labels).float().mean() + + # 更新统计 + train_loss.update(loss.item(), 
labels.size(0)) + train_acc.update(batch_acc.item(), labels.size(0)) + + # 更新进度条 + pbar.set_postfix({ + 'loss': train_loss.avg, + 'acc': train_acc.avg + }) + + train_losses.append(train_loss.avg) + train_accs.append(train_acc.avg) + + # 验证 + model.eval() + val_loss = AverageMeter() + val_acc = AverageMeter() + + all_labels = [] + all_preds = [] + + with torch.no_grad(): + for batch in tqdm(val_loader, desc='验证'): + # 获取数据和标签 + imgs = batch[modal_key].to(device) + labels = batch['label'].to(device) + + # 前向传播 + outputs = model(imgs) + + # 计算损失 + loss = criterion(outputs, labels) + + # 计算准确率 + _, preds = torch.max(outputs, 1) + batch_acc = (preds == labels).float().mean() + + # 更新统计 + val_loss.update(loss.item(), labels.size(0)) + val_acc.update(batch_acc.item(), labels.size(0)) + + # 保存预测结果用于计算指标 + all_labels.extend(labels.cpu().numpy()) + all_preds.extend(preds.cpu().numpy()) + + val_losses.append(val_loss.avg) + val_accs.append(val_acc.avg) + + # 计算其他评估指标 + val_metrics = compute_metrics(all_labels, all_preds) + + # 打印当前epoch的结果 + print(f'训练损失: {train_loss.avg:.4f} 训练准确率: {train_acc.avg:.4f}') + print(f'验证损失: {val_loss.avg:.4f} 验证准确率: {val_acc.avg:.4f}') + print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}') + + # 检查是否为最佳验证准确率 + if val_acc.avg > best_val_acc: + best_val_acc = val_acc.avg + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'val_acc': val_acc.avg, + 'val_metrics': val_metrics + }, os.path.join(save_dir, f'{model_name}_best_acc.pth')) + print(f'保存新的最佳模型,验证准确率: {val_acc.avg:.4f}') + + # 早停检查 + early_stopping(val_loss.avg, model) + if early_stopping.early_stop: + print(f"早停! 在第 {epoch+1} 个Epoch停止训练") + break + + # 计算总训练时间 + total_time = time.time() - start_time + print(f'训练完成! 
class AverageMeter:
    """Track the most recent value and the weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # An empty batch (n == 0) before any data must not divide by zero.
        self.avg = self.sum / self.count if self.count else 0


class EarlyStopping:
    """Stop training when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    The model's state dict is written to ``path`` whenever the loss improves;
    ``early_stop`` becomes True after ``patience`` consecutive calls without
    improvement.

    Args:
        patience: Number of non-improving calls tolerated before stopping.
        delta: Margin added to the best score when testing for improvement.
        path: File the best state dict is saved to.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lowest validation loss checkpointed so far (used for the log message).
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        # NaN is the only value that differs from itself. A diverged (NaN) loss
        # must never count as an improvement, otherwise it would overwrite the
        # best checkpoint and poison every later comparison.
        is_nan = val_loss != val_loss
        score = float('-inf') if is_nan else -val_loss

        if self.best_score is None:
            # The first call always checkpoints so that ``path`` exists.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif is_nan or score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's state dict and log the previous and new best loss."""
        torch.save(model.state_dict(), self.path)
        print(f'验证损失降低 ({self.val_loss_min:.6f} --> {val_loss:.6f}). 保存模型...')
        self.val_loss_min = val_loss


def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """Plot training/validation loss and accuracy curves side by side.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accs: Per-epoch training accuracies.
        val_accs: Per-epoch validation accuracies.
        save_path: If given, the figure is also written to this path.
    """
    fig = plt.figure(figsize=(12, 5))

    # Left panel: loss.
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='训练损失')
    plt.plot(val_losses, label='验证损失')
    plt.title('损失曲线')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.legend()
    plt.grid(True)

    # Right panel: accuracy.
    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='训练准确率')
    plt.plot(val_accs, label='验证准确率')
    plt.title('准确率曲线')
    plt.xlabel('Epoch')
    plt.ylabel('准确率')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    plt.show()
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
    """Plot the confusion matrix of ``y_pred`` against ``y_true`` as a heatmap.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Optional tick labels; automatic labels are used when
            omitted or empty.
        save_path: If given, the figure is also written to this path.
    """
    cm = confusion_matrix(y_true, y_pred)

    # Explicit None/length test: truthiness of a numpy array raises ValueError.
    if class_names is None or len(class_names) == 0:
        tick_labels = "auto"
    else:
        tick_labels = class_names

    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')

    if save_path:
        plt.savefig(save_path)

    plt.show()
    # Release the figure; otherwise repeated calls accumulate open figures.
    plt.close(fig)


def compute_metrics(y_true, y_pred, y_proba=None):
    """Compute binary classification metrics.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_proba: Accepted for interface compatibility; currently unused.

    Returns:
        dict: ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    # zero_division=0 on every score keeps the three consistent and avoids
    # UndefinedMetricWarning when a class is never predicted or never present.
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='binary', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='binary', zero_division=0),
        'f1': f1_score(y_true, y_pred, average='binary', zero_division=0),
    }


def save_results(results, filename):
    """Write ``results`` to ``filename`` as one ``key: value`` line per entry.

    The file is always written as UTF-8 so that non-ASCII keys or values do
    not depend on the platform's default encoding.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f"{key}: {value}\n" for key, value in results.items())