已跑通一次,准确率85%,500样本量

This commit is contained in:
lotus 2025-04-17 16:26:35 +08:00
parent 2e859eea14
commit d1ab4a1f19
25 changed files with 1376 additions and 0 deletions

3
.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
data/
__pycache__/
output/

Binary file not shown.

Binary file not shown.

36
config.py Normal file
View File

@ -0,0 +1,36 @@
import os
import torch
# ---------------------------------------------------------------------------
# Project-wide configuration constants. main.py pulls these in with
# `from config import *`.
# ---------------------------------------------------------------------------
# Path configuration
DATA_ROOT = "./data" # data root; contains one sub-folder per class ("000", "001", ...)
OUTPUT_DIR = "./output"
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
# Make sure the checkpoint directory exists.
# NOTE: this runs as a side effect of importing this module.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# Data configuration
IMG_SIZE = 224 # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 32
NUM_WORKERS = 4
TRAIN_RATIO = 0.8
# NOTE(review): VAL_RATIO is not referenced in the code visible here; load_data
# derives the validation share as 1 - train_ratio. Confirm before relying on it.
VAL_RATIO = 0.2
# Sample-count control
MAX_SAMPLES_PER_CLASS = 1000 # maximum number of samples read per class
NORMAL_CLASS = "000" # folder name of the normal class
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"] # folder names of the abnormal classes
# Model configuration
HIDDEN_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 6
DROPOUT = 0.1
NUM_CLASSES = 2 # binary classification: normal vs. abnormal
# Training configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 50
# NOTE(review): the training functions visible here build EarlyStopping with a
# literal patience of 10 and do not appear to read this constant.
EARLY_STOPPING_PATIENCE = 10

View File

@ -0,0 +1,171 @@
import os
import glob
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
class LeukemiaDataset(Dataset):
    """Dataset of paired DIFF / WNB scattergram images with a binary label.

    Each item is a dict with the keys ``'id'``, ``'diff_img'``, ``'wnb_img'``
    and ``'label'``.
    """

    def __init__(self, sample_paths, labels, transform=None):
        """
        Args:
            sample_paths: list of ``(diff_path, wnb_path)`` tuples.
            labels: list of labels aligned with ``sample_paths``
                (0: normal, 1: abnormal).
            transform: optional callable applied to both images of a pair.
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of (DIFF, WNB) pairs."""
        return len(self.sample_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        # ``convert`` forces the pixel data to be read and returns a new image,
        # so the underlying file can be released as soon as the block exits.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Load, transform and return the sample at position ``idx``."""
        diff_path, wnb_path = self.sample_paths[idx]
        # The sample id is the DIFF file name without its 'Diff.png' suffix.
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')

        diff_img = self._load_rgb(diff_path)
        wnb_img = self._load_rgb(wnb_path)

        if self.transform:
            diff_img = self.transform(diff_img)
            wnb_img = self.transform(wnb_img)

        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }
def _collect_pairs(folder, limit, rng):
    """Return at most ``limit`` complete (diff_path, wnb_path) pairs from ``folder``.

    Only DIFF files that have a matching ``<id>Wnb.png`` next to them are
    considered, and the cap is applied after that filter so that incomplete
    pairs do not eat into the requested sample budget.
    """
    pairs = []
    # glob order depends on the filesystem; sorting gives the seeded shuffle
    # below a stable starting point.
    for diff_file in sorted(glob.glob(os.path.join(folder, "*Diff.png"))):
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(folder, f"{sample_id}Wnb.png")
        if os.path.exists(wnb_file):
            pairs.append((diff_file, wnb_file))
    if len(pairs) > limit:
        rng.shuffle(pairs)
        pairs = pairs[:limit]
    return pairs


def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
              max_samples_per_class=500, normal_class="000", abnormal_classes=None,
              seed=42):
    """
    Collect the image pairs, split them and build the data loaders.

    Args:
        data_root: data root directory containing the class sub-folders
            ("000", "001", ...).
        img_size: images are resized to ``img_size x img_size``.
        batch_size: batch size of both loaders.
        num_workers: number of DataLoader worker processes.
        train_ratio: fraction of the samples used for training; the rest is
            used for validation.
        max_samples_per_class: maximum number of normal samples, and also the
            total budget of abnormal samples (shared evenly between the
            abnormal folders).
        normal_class: folder name of the normal class.
        abnormal_classes: list of folder names of the abnormal classes.
        seed: seed for the sub-sampling shuffle. With a fixed seed the same
            files are selected on every call, so a later evaluation run sees
            the same validation split as the training run. Pass ``None`` for
            a different random subset on each call.

    Returns:
        train_loader: training DataLoader.
        val_loader: validation DataLoader.

    Raises:
        ValueError: if ``abnormal_classes`` is an empty list.
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]
    if not abnormal_classes:
        raise ValueError("abnormal_classes must contain at least one folder name")

    rng = random.Random(seed)

    # Image transform (resize + tensor conversion + channel normalisation)
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Normal samples
    normal_samples = _collect_pairs(
        os.path.join(data_root, normal_class), max_samples_per_class, rng
    )

    # Abnormal samples: the per-class budget is split evenly between folders,
    # e.g. max_samples_per_class=500 over 7 folders gives 71 per folder.
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)
    abnormal_samples = []
    for abnormal_class in abnormal_classes:
        abnormal_samples.extend(
            _collect_pairs(
                os.path.join(data_root, abnormal_class), samples_per_abnormal_class, rng
            )
        )

    # Assemble samples and labels (0: normal, 1: abnormal)
    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)
    print(f"收集到的正常样本数: {len(normal_samples)}")
    print(f"收集到的异常样本数: {len(abnormal_samples)}")
    print(f"总样本数: {len(all_samples)}")

    # Stratified train / validation split
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1-train_ratio),
        stratify=all_labels,  # keep the class ratio equal in both splits
        random_state=42
    )
    train_samples = [all_samples[i] for i in train_indices]
    train_labels = [all_labels[i] for i in train_indices]
    val_samples = [all_samples[i] for i in val_indices]
    val_labels = [all_labels[i] for i in val_indices]

    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    print(f"训练集样本数: {len(train_dataset)}, 验证集样本数: {len(val_dataset)}")
    return train_loader, val_loader

View File

Binary file not shown.

Binary file not shown.

185
evaluation/evaluate.py Normal file
View File

@ -0,0 +1,185 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
from sklearn.manifold import TSNE
import seaborn as sns
from training.utils import compute_metrics, plot_confusion_matrix, save_results
def evaluate_model(model, data_loader, device, class_names=None, modal_key=None):
    """Evaluate a model on ``data_loader`` and print/plot the results.

    Args:
        model: the network to evaluate.
        data_loader: yields dict batches with the keys ``'diff_img'``,
            ``'wnb_img'`` and ``'label'``.
        device: device the batches are moved to.
        class_names: tick labels for the confusion matrix; defaults to
            ``['正常', '异常']``.
        modal_key: ``None`` (default) calls ``model(diff_img, wnb_img)``.
            Pass ``'diff_img'`` or ``'wnb_img'`` to evaluate a single-input
            model, which is then called as ``model(batch[modal_key])``.

    Returns:
        ``(metrics, all_labels, all_preds, probs)`` where ``metrics`` is the
        dict returned by ``compute_metrics`` and ``probs`` is a NumPy array of
        the softmax outputs.
    """
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for batch in data_loader:
            labels = batch['label'].to(device)
            # Forward pass: two-input fusion model unless one modality is selected.
            if modal_key is None:
                diff_imgs = batch['diff_img'].to(device)
                wnb_imgs = batch['wnb_img'].to(device)
                outputs = model(diff_imgs, wnb_imgs)
            else:
                outputs = model(batch[modal_key].to(device))
            # Class probabilities and hard predictions
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    metrics = compute_metrics(all_labels, all_preds, all_probs)

    if class_names is None:
        class_names = ['正常', '异常']
    plot_confusion_matrix(all_labels, all_preds, class_names)

    print(f"准确率: {metrics['accuracy']:.4f}")
    print(f"精确率: {metrics['precision']:.4f}")
    print(f"召回率: {metrics['recall']:.4f}")
    print(f"F1值: {metrics['f1']:.4f}")
    return metrics, all_labels, all_preds, np.array(all_probs)
def plot_roc_curve(all_labels, all_probs, save_path=None):
    """Draw the ROC curve of the positive (abnormal) class and return its AUC.

    Args:
        all_labels: ground-truth labels.
        all_probs: array indexed as ``[:, 1]`` to obtain the positive-class
            probability of each sample.
        save_path: if given, the figure is also written to this path.

    Returns:
        The area under the ROC curve.
    """
    # Column 1 holds the probability of the positive class.
    fpr, tpr, _thresholds = roc_curve(all_labels, all_probs[:, 1])
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    curve_label = f'ROC曲线 (AUC = {roc_auc:.3f})'
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=curve_label)
    # Diagonal reference line
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率')
    plt.ylabel('真阳性率')
    plt.title('受试者工作特征(ROC)曲线')
    plt.legend(loc='lower right')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
    plt.show()
    return roc_auc
def extract_and_visualize_features(model, data_loader, device, save_path=None):
    """Extract per-modality features and show them as three t-SNE panels.

    Panels 1 and 2 show the ``'diff'`` and ``'wnb'`` features returned by
    ``model.extract_features``; panel 3 shows both concatenated along axis 1.

    Args:
        model: must provide ``extract_features(diff_imgs, wnb_imgs)`` returning
            a dict with the keys ``'diff'`` and ``'wnb'``.
        data_loader: yields dict batches with ``'diff_img'``, ``'wnb_img'`` and
            ``'label'``.
        device: device the image batches are moved to.
        save_path: if given, the figure is also written to this path.
    """
    model.eval()
    features_dict = {'diff': [], 'wnb': []}
    collected_labels = []
    with torch.no_grad():
        for batch in data_loader:
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            batch_features = model.extract_features(diff_imgs, wnb_imgs)
            for modality, store in features_dict.items():
                store.extend(batch_features[modality].cpu().numpy())
            collected_labels.extend(batch['label'].cpu().numpy())
    all_labels = np.array(collected_labels)

    def draw_panel(position, features, title):
        # Reduce to 2-D with t-SNE, then scatter one colour per label.
        embedded = TSNE(n_components=2, random_state=42).fit_transform(features)
        plt.subplot(1, 3, position)
        for label in np.unique(all_labels):
            mask = all_labels == label
            plt.scatter(embedded[mask, 0], embedded[mask, 1],
                        label=f"类别 {label}", alpha=0.7)
        plt.title(title)
        plt.xlabel('t-SNE 特征 1')
        plt.ylabel('t-SNE 特征 2')
        plt.legend()
        plt.grid(True)

    plt.figure(figsize=(15, 5))
    # One panel per modality
    for position, (modality, features) in enumerate(features_dict.items(), start=1):
        draw_panel(position, np.array(features), f'{modality} 特征 t-SNE 可视化')
    # Third panel: both modalities concatenated feature-wise
    combined_features = np.concatenate(
        [features_dict['diff'], features_dict['wnb']], axis=1
    )
    draw_panel(3, combined_features, '融合特征 t-SNE 可视化')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()
def compare_models(models_results, model_names, save_path=None):
    """Plot a grouped bar chart comparing several models.

    Args:
        models_results: list of metric dicts, each containing the keys
            ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``.
        model_names: legend labels, paired with ``models_results`` by position.
        save_path: if given, the figure is also written to this path.
    """
    metrics = ['accuracy', 'precision', 'recall', 'f1']
    # One row per model, one column per metric
    values = np.array(
        [[result[metric] for metric in metrics] for result in models_results]
    )

    plt.figure(figsize=(10, 6))
    n_models = len(models_results)
    x = np.arange(len(metrics))
    width = 0.8 / n_models
    for offset, (name, scores) in enumerate(zip(model_names, values)):
        plt.bar(x + offset * width, scores, width, label=name)

    plt.xlabel('评估指标')
    plt.ylabel('分数')
    plt.title('不同模型性能比较')
    # Centre each tick under its group of bars
    plt.xticks(x + width * (n_models - 1) / 2, metrics)
    plt.ylim(0, 1.0)
    plt.legend()
    plt.grid(True, axis='y')

    if save_path:
        plt.savefig(save_path)
    plt.show()

View File

349
main.py Normal file
View File

@ -0,0 +1,349 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
import matplotlib.pyplot as plt
from config import *
from data_preprocessing.data_loader import load_data
from models.image_models import VisionTransformer
from models.fusion_model import MultiModalFusionModel, SingleModalModel
from training.train import train_model, train_single_modal_model
from training.utils import plot_training_curves, save_results
from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
def _build_model(model_cls, device):
    """Instantiate ``model_cls`` with the shared config hyper-parameters on ``device``."""
    return model_cls(
        img_size=IMG_SIZE,
        patch_size=16,
        in_channels=3,
        embed_dim=HIDDEN_DIM,
        depth=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_classes=NUM_CLASSES
    ).to(device)


class _SingleModalAdapter(nn.Module):
    """Expose a single-input model through the ``(diff_img, wnb_img)`` call signature.

    ``evaluate_model`` is called with two image batches; a ``SingleModalModel``
    accepts only one, so this wrapper forwards the batch selected by
    ``modal_key`` and ignores the other.
    """

    def __init__(self, model, modal_key):
        super().__init__()
        self.model = model
        self.use_diff = modal_key == 'diff_img'

    def forward(self, diff_img, wnb_img):
        return self.model(diff_img if self.use_diff else wnb_img)


def _report_multi_modal(model, val_loader, device, output_dir):
    """Evaluate the fusion model and write ROC curve, t-SNE plot and metrics."""
    metrics, labels, _preds, probs = evaluate_model(
        model=model,
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常']
    )
    plot_roc_curve(
        labels,
        probs,
        save_path=os.path.join(output_dir, 'multi_modal_roc_curve.png')
    )
    extract_and_visualize_features(
        model=model,
        data_loader=val_loader,
        device=device,
        save_path=os.path.join(output_dir, 'feature_visualization.png')
    )
    save_results(metrics, os.path.join(output_dir, 'multi_modal_results.txt'))
    return metrics


def _run_train(args, train_loader, val_loader, device):
    """Train the fusion model, plot its training curves and evaluate it."""
    print("创建多模态融合模型...")
    multi_modal_model = _build_model(MultiModalFusionModel, device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    print("开始训练多模态融合模型...")
    multi_modal_model, multi_modal_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal'
    )
    plot_training_curves(
        multi_modal_history['train_losses'],
        multi_modal_history['val_losses'],
        multi_modal_history['train_accs'],
        multi_modal_history['val_accs'],
        save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png')
    )

    print("\n评估多模态融合模型...")
    _report_multi_modal(multi_modal_model, val_loader, device, args.output_dir)


def _plot_comparison_curves(diff_history, wnb_history, multi_history, save_path):
    """Draw the 2x2 loss/accuracy comparison figure for the three models."""
    histories = [('DIFF', diff_history), ('WNB', wnb_history), ('多模态', multi_history)]
    # (history key, panel title, y-axis label, legend suffix)
    panels = [
        ('train_losses', '训练损失', '损失', '训练'),
        ('val_losses', '验证损失', '损失', '验证'),
        ('train_accs', '训练准确率', '准确率', '训练'),
        ('val_accs', '验证准确率', '准确率', '验证'),
    ]
    plt.figure(figsize=(12, 8))
    for position, (key, title, ylabel, suffix) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        for prefix, history in histories:
            plt.plot(history[key], label=f'{prefix}{suffix}')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()


def _run_compare(args, train_loader, val_loader, device):
    """Train the DIFF-only, WNB-only and fusion models and compare them."""
    print("创建模型进行比较...")
    # Creation order (DIFF, WNB, fusion) is kept so weight initialisation
    # consumes the RNG in the same sequence as before.
    diff_model = _build_model(SingleModalModel, device)
    wnb_model = _build_model(SingleModalModel, device)
    multi_modal_model = _build_model(MultiModalFusionModel, device)

    criterion = nn.CrossEntropyLoss()

    print("训练DIFF散点图单模态模型...")
    diff_optimizer = optim.AdamW(diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    diff_model, diff_history = train_single_modal_model(
        model=diff_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=diff_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='diff_only',
        modal_key='diff_img'
    )

    print("训练WNB散点图单模态模型...")
    wnb_optimizer = optim.AdamW(wnb_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    wnb_model, wnb_history = train_single_modal_model(
        model=wnb_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=wnb_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='wnb_only',
        modal_key='wnb_img'
    )

    print("训练多模态融合模型...")
    multi_optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    multi_modal_model, multi_history = train_model(
        model=multi_modal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=multi_optimizer,
        device=device,
        num_epochs=args.epochs,
        save_dir=MODEL_SAVE_DIR,
        model_name='multi_modal'
    )

    # evaluate_model feeds two image batches to the model, so the single-input
    # models are wrapped to select their own modality.
    print("\n评估DIFF散点图单模态模型...")
    diff_metrics, _, _, _ = evaluate_model(
        model=_SingleModalAdapter(diff_model, 'diff_img'),
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常']
    )
    print("\n评估WNB散点图单模态模型...")
    wnb_metrics, _, _, _ = evaluate_model(
        model=_SingleModalAdapter(wnb_model, 'wnb_img'),
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常']
    )
    print("\n评估多模态融合模型...")
    multi_metrics, _, _, _ = evaluate_model(
        model=multi_modal_model,
        data_loader=val_loader,
        device=device,
        class_names=['正常', '异常']
    )

    compare_models(
        [diff_metrics, wnb_metrics, multi_metrics],
        ['DIFF散点图', 'WNB散点图', '多模态融合'],
        save_path=os.path.join(args.output_dir, 'model_comparison.png')
    )
    _plot_comparison_curves(
        diff_history, wnb_history, multi_history,
        os.path.join(args.output_dir, 'all_models_training_curves.png')
    )

    save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
    save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
    save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))


def _run_evaluate(args, val_loader, device):
    """Load the saved fusion checkpoint and evaluate it on the validation loader."""
    print("加载预训练的多模态模型...")
    model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')
    if not os.path.exists(model_path):
        print(f"错误:找不到预训练模型 {model_path}")
        return
    multi_modal_model = _build_model(MultiModalFusionModel, device)
    # map_location lets a checkpoint written on a GPU machine load on CPU.
    multi_modal_model.load_state_dict(torch.load(model_path, map_location=device))

    print("评估多模态融合模型...")
    _report_multi_modal(multi_modal_model, val_loader, device, args.output_dir)


def main():
    """Command-line entry point: parse arguments, load data, run the chosen mode.

    Modes:
        train: train and evaluate the multi-modal fusion model.
        compare: train DIFF-only, WNB-only and fusion models and compare them.
        evaluate: evaluate the saved ``multi_modal_best.pth`` checkpoint.
    """
    parser = argparse.ArgumentParser(description='白血病智能筛查系统')
    parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
    parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
    parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
    parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
    parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
    parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'], default='train', help='运行模式')
    args = parser.parse_args()

    # Make sure the output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    print("加载数据...")
    train_loader, val_loader = load_data(
        data_root=args.data_root,
        img_size=IMG_SIZE,
        batch_size=args.batch_size,
        num_workers=NUM_WORKERS,
        train_ratio=TRAIN_RATIO,
        max_samples_per_class=MAX_SAMPLES_PER_CLASS,
        normal_class=NORMAL_CLASS,
        abnormal_classes=ABNORMAL_CLASSES
    )
    print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")

    device = DEVICE
    print(f"使用设备: {device}")

    if args.mode == 'train':
        _run_train(args, train_loader, val_loader, device)
    elif args.mode == 'compare':
        _run_compare(args, train_loader, val_loader, device)
    elif args.mode == 'evaluate':
        _run_evaluate(args, val_loader, device)


if __name__ == "__main__":
    main()

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

116
models/fusion_model.py Normal file
View File

@ -0,0 +1,116 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.image_models import VisionTransformer
class MultiModalFusionModel(nn.Module):
    """Fusion model combining features of the DIFF and WNB scattergrams.

    Each modality has its own ``VisionTransformer`` encoder; the two feature
    vectors are concatenated, projected back to ``embed_dim`` and classified.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()
        encoder_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )
        # Separate encoders (no weight sharing), created DIFF first, then WNB.
        self.diff_encoder = VisionTransformer(**encoder_kwargs)
        self.wnb_encoder = VisionTransformer(**encoder_kwargs)
        # Fusion layer: 2*E -> E
        self.fusion = nn.Sequential(
            nn.Linear(embed_dim * 2, embed_dim),
            nn.LayerNorm(embed_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, diff_img, wnb_img):
        """
        Forward pass.

        Args:
            diff_img: DIFF scattergram batch [B, C, H, W]
            wnb_img: WNB scattergram batch [B, C, H, W]
        Returns:
            logits: classification logits [B, num_classes]
        """
        per_modality = (
            self.diff_encoder(diff_img),  # [B, E]
            self.wnb_encoder(wnb_img),    # [B, E]
        )
        fused = self.fusion(torch.cat(per_modality, dim=1))  # [B, 2*E] -> [B, E]
        return self.classifier(fused)  # [B, num_classes]

    def extract_features(self, diff_img, wnb_img):
        """Return the per-modality encoder outputs under the keys 'diff' and 'wnb'."""
        return {
            'diff': self.diff_encoder(diff_img),
            'wnb': self.wnb_encoder(wnb_img)
        }
class SingleModalModel(nn.Module):
    """Single-modality model (one encoder plus a linear head) for comparison runs."""

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
        super().__init__()
        encoder_kwargs = dict(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            depth=depth,
            num_heads=num_heads,
            dropout=dropout,
        )
        # Image feature extractor
        self.encoder = VisionTransformer(**encoder_kwargs)
        # Classification head
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, img):
        """
        Forward pass.

        Args:
            img: input image batch [B, C, H, W]
        Returns:
            logits: classification logits [B, num_classes]
        """
        # encoder: [B, C, H, W] -> [B, E]; classifier: [B, E] -> [B, num_classes]
        return self.classifier(self.encoder(img))

78
models/image_models.py Normal file
View File

@ -0,0 +1,78 @@
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F
import torchvision.models as models
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    Uses a convolution whose kernel size and stride both equal ``patch_size``.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``x`` of shape [B, C, H, W] to patch embeddings [B, N, E]."""
        projected = self.proj(x)            # [B, E, H/P, W/P]
        tokens = projected.flatten(2)       # [B, E, (H/P)*(W/P)]
        return tokens.transpose(1, 2)       # [B, (H/P)*(W/P), E]
class VisionTransformer(nn.Module):
    """Transformer-based image feature extractor returning the CLS-token feature.

    NOTE(review): ``self.norm`` is created in ``__init__`` but ``forward`` never
    applies it. It is kept as-is so the parameter set (and therefore saved
    state dicts) stays unchanged.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3,
                 embed_dim=768, depth=6, num_heads=12, dropout=0.1):
        super().__init__()
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.n_patches = self.patch_embed.n_patches
        # Learnable position embedding (one slot per patch plus the CLS token)
        # and the CLS token itself
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Transformer encoder stack
        layer = TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=embed_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True
        )
        self.transformer = TransformerEncoder(layer, num_layers=depth)
        # Layer normalisation (see class note: not used in forward)
        self.norm = nn.LayerNorm(embed_dim)
        # Initialise the two learnable embeddings
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Encode ``x`` of shape [B, C, H, W] into one feature vector per image [B, E]."""
        n_images = x.shape[0]
        patch_tokens = self.patch_embed(x)                      # [B, N, E]
        # Prepend one CLS token per image
        cls_tokens = self.cls_token.expand(n_images, -1, -1)    # [B, 1, E]
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)   # [B, N+1, E]
        # Add the position embedding and run the encoder
        encoded = self.transformer(tokens + self.pos_embed)
        # The CLS token (index 0) represents the whole image
        return encoded[:, 0]                                    # [B, E]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

312
training/train.py Normal file
View File

@ -0,0 +1,312 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
import time
from training.utils import AverageMeter, EarlyStopping, plot_training_curves, compute_metrics
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training epoch of a two-input model.

    Args:
        model: called as ``model(diff_imgs, wnb_imgs)``.
        train_loader: yields dict batches with ``'diff_img'``, ``'wnb_img'``
            and ``'label'``.
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        ``(average loss, average accuracy)``, both weighted by batch size.
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress = tqdm(train_loader, desc='训练')
    for batch in progress:
        diff_imgs = batch['diff_img'].to(device)
        wnb_imgs = batch['wnb_img'].to(device)
        labels = batch['label'].to(device)
        n_samples = labels.size(0)

        # Forward pass and loss
        outputs = model(diff_imgs, wnb_imgs)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Batch accuracy from the arg-max class
        preds = torch.max(outputs, 1)[1]
        batch_acc = (preds == labels).float().mean()

        loss_meter.update(loss.item(), n_samples)
        acc_meter.update(batch_acc.item(), n_samples)
        progress.set_postfix({
            'loss': loss_meter.avg,
            'acc': acc_meter.avg
        })
    return loss_meter.avg, acc_meter.avg
def validate(model, val_loader, criterion, device):
    """Evaluate a two-input model on the validation loader.

    Returns:
        ``(average loss, average accuracy, metrics)`` where ``metrics`` is the
        dict from ``compute_metrics`` with ``'loss'`` added and ``'accuracy'``
        replaced by the batch-size-weighted running average.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    y_true, y_pred, y_prob = [], [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='验证'):
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].to(device)
            n_samples = labels.size(0)

            outputs = model(diff_imgs, wnb_imgs)
            loss = criterion(outputs, labels)

            probs = torch.softmax(outputs, dim=1)
            preds = torch.max(outputs, 1)[1]
            batch_acc = (preds == labels).float().mean()

            loss_meter.update(loss.item(), n_samples)
            acc_meter.update(batch_acc.item(), n_samples)

            # Keep everything needed for the epoch-level metrics
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
            y_prob.extend(probs.cpu().numpy())

    metrics = compute_metrics(y_true, y_pred, y_prob)
    metrics['loss'] = loss_meter.avg
    metrics['accuracy'] = acc_meter.avg
    return loss_meter.avg, acc_meter.avg, metrics
def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                num_epochs=50, save_dir='./models', model_name='model', patience=10):
    """Train a two-input model with early stopping on the validation loss.

    Two checkpoints are written to ``save_dir``:
    ``{model_name}_best.pth`` (state dict with the lowest validation loss,
    written by ``EarlyStopping``) and ``{model_name}_best_acc.pth`` (dict with
    model/optimizer state for the highest validation accuracy).

    Args:
        model, train_loader, val_loader, criterion, optimizer, device:
            forwarded to ``train_epoch`` / ``validate``.
        num_epochs: maximum number of epochs.
        save_dir: checkpoint directory (created if missing).
        model_name: prefix of the checkpoint file names.
        patience: number of epochs without validation-loss improvement before
            training stops early.

    Returns:
        ``(model, history)``: the model with the lowest-validation-loss weights
        loaded, and a dict with ``train_losses``, ``val_losses``,
        ``train_accs``, ``val_accs``, ``best_val_acc`` and ``total_time``.
    """
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')

    early_stopping = EarlyStopping(patience=patience, path=best_model_path)

    # Training history
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, val_metrics = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Separate checkpoint for the best validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, os.path.join(save_dir, f'{model_name}_best_acc.pth'))
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        # Early stopping (also writes the lowest-loss checkpoint)
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the lowest-loss weights. The file does not exist when no epoch
    # ran (num_epochs <= 0); map_location keeps the tensors on ``device``.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }
def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device,
                             num_epochs=50, save_dir='./models', model_name='single_modal',
                             modal_key='diff_img', patience=10):
    """Train a single-input model on one modality with early stopping.

    Mirrors ``train_model`` but feeds only ``batch[modal_key]`` to the model.

    Args:
        model: called as ``model(imgs)``.
        train_loader, val_loader: yield dict batches containing ``modal_key``
            and ``'label'``.
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        save_dir: checkpoint directory (created if missing).
        model_name: prefix of the checkpoint file names.
        modal_key: batch key of the image to use (``'diff_img'`` or ``'wnb_img'``).
        patience: number of epochs without validation-loss improvement before
            training stops early.

    Returns:
        ``(model, history)`` with the same layout as ``train_model``.
    """
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')

    early_stopping = EarlyStopping(patience=patience, path=best_model_path)

    # Training history
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        # ---- training ----
        model.train()
        train_loss = AverageMeter()
        train_acc = AverageMeter()
        pbar = tqdm(train_loader, desc='训练')
        for batch in pbar:
            imgs = batch[modal_key].to(device)
            labels = batch['label'].to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            batch_acc = (preds == labels).float().mean()

            train_loss.update(loss.item(), labels.size(0))
            train_acc.update(batch_acc.item(), labels.size(0))
            pbar.set_postfix({
                'loss': train_loss.avg,
                'acc': train_acc.avg
            })
        train_losses.append(train_loss.avg)
        train_accs.append(train_acc.avg)

        # ---- validation ----
        model.eval()
        val_loss = AverageMeter()
        val_acc = AverageMeter()
        all_labels = []
        all_preds = []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc='验证'):
                imgs = batch[modal_key].to(device)
                labels = batch['label'].to(device)

                outputs = model(imgs)
                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                batch_acc = (preds == labels).float().mean()

                val_loss.update(loss.item(), labels.size(0))
                val_acc.update(batch_acc.item(), labels.size(0))

                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())
        val_losses.append(val_loss.avg)
        val_accs.append(val_acc.avg)

        val_metrics = compute_metrics(all_labels, all_preds)

        print(f'训练损失: {train_loss.avg:.4f} 训练准确率: {train_acc.avg:.4f}')
        print(f'验证损失: {val_loss.avg:.4f} 验证准确率: {val_acc.avg:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Separate checkpoint for the best validation accuracy
        if val_acc.avg > best_val_acc:
            best_val_acc = val_acc.avg
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc.avg,
                'val_metrics': val_metrics
            }, os.path.join(save_dir, f'{model_name}_best_acc.pth'))
            print(f'保存新的最佳模型,验证准确率: {val_acc.avg:.4f}')

        # Early stopping (also writes the lowest-loss checkpoint)
        early_stopping(val_loss.avg, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the lowest-loss weights. The file does not exist when no epoch
    # ran (num_epochs <= 0); map_location keeps the tensors on ``device``.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }

126
training/utils.py Normal file
View File

@ -0,0 +1,126 @@
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import os
class AverageMeter:
    """Track the latest value and the weighted running average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` observations."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
class EarlyStopping:
    """Stop training early when the validation loss stops improving.

    Call the instance once per epoch with the validation loss and the model.
    Every improvement saves ``model.state_dict()`` to ``path``; after
    ``patience`` consecutive non-improving calls ``early_stop`` becomes True.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience: number of consecutive non-improving calls tolerated.
            delta: minimum decrease of the loss that counts as an improvement.
            path: file the best state dict is written to.
        """
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Lowest validation loss saved so far; used only for the log message.
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        # Higher score is better, so the loss is negated.
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's state dict and log the loss improvement."""
        torch.save(model.state_dict(), self.path)
        # Report previous minimum --> new loss. best_score is already updated
        # when this runs, so it cannot serve as the "previous" value.
        print(f'验证损失降低 ({self.val_loss_min:.6f} --> {val_loss:.6f}). 保存模型...')
        self.val_loss_min = val_loss
def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """Plot training/validation loss (left) and accuracy (right) per epoch.

    Args:
        train_losses, val_losses: per-epoch loss values.
        train_accs, val_accs: per-epoch accuracy values.
        save_path: if given, the figure is also written to this path.
    """
    # (train series, train label, val series, val label, title, y-axis label)
    panels = (
        (train_losses, '训练损失', val_losses, '验证损失', '损失曲线', '损失'),
        (train_accs, '训练准确率', val_accs, '验证准确率', '准确率曲线', '准确率'),
    )
    plt.figure(figsize=(12, 5))
    for position, (train_series, train_label, val_series, val_label, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()
def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
    """Draw the confusion matrix of ``y_pred`` against ``y_true`` as a heat map.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        class_names: tick labels for both axes; ``"auto"`` is used when empty
            or ``None``.
        save_path: if given, the figure is also written to this path.
    """
    cm = confusion_matrix(y_true, y_pred)
    tick_labels = class_names if class_names else "auto"
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=tick_labels,
        yticklabels=tick_labels,
    )
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    if save_path:
        plt.savefig(save_path)
    plt.show()
def compute_metrics(y_true, y_pred, y_proba=None):
    """Compute accuracy, precision, recall and F1 for a binary problem.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        y_proba: accepted for interface compatibility; not used by this
            function.

    Returns:
        Dict with the keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
        ``'f1'``.
    """
    # zero_division=0 on all three scores: an undefined score (e.g. no
    # positive predictions or no positive labels) is reported as 0 without
    # a warning, the same way for each metric.
    binary_kwargs = {'average': 'binary', 'zero_division': 0}
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, **binary_kwargs),
        'recall': recall_score(y_true, y_pred, **binary_kwargs),
        'f1': f1_score(y_true, y_pred, **binary_kwargs)
    }
def save_results(results, filename):
    """Write ``results`` to a text file, one ``key: value`` line per entry.

    Args:
        results: mapping of metric name to value.
        filename: path of the file to (over)write.
    """
    # Explicit UTF-8 so the output does not depend on the platform's default
    # encoding.
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f"{key}: {value}\n" for key, value in results.items())