已跑通一次,准确率85%,500样本量
This commit is contained in:
parent
2e859eea14
commit
d1ab4a1f19
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
data/
|
||||
_pycache_/
|
||||
output/
|
||||
BIN
__pycache__/config.cpython-312.pyc
Normal file
BIN
__pycache__/config.cpython-312.pyc
Normal file
Binary file not shown.
BIN
__pycache__/config.cpython-38.pyc
Normal file
BIN
__pycache__/config.cpython-38.pyc
Normal file
Binary file not shown.
36
config.py
Normal file
36
config.py
Normal file
@ -0,0 +1,36 @@
|
||||
import os
|
||||
import torch
|
||||
|
||||
# 路径配置
|
||||
DATA_ROOT = "./data" # 数据根目录,包含000,001等子文件夹
|
||||
OUTPUT_DIR = "./output"
|
||||
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
|
||||
# 数据配置
|
||||
IMG_SIZE = 224 # 调整图像大小
|
||||
BATCH_SIZE = 32
|
||||
NUM_WORKERS = 4
|
||||
TRAIN_RATIO = 0.8
|
||||
VAL_RATIO = 0.2
|
||||
|
||||
# 样本数量控制
|
||||
MAX_SAMPLES_PER_CLASS = 1000 # 每个类别最多读取的样本数
|
||||
NORMAL_CLASS = "000" # 正常类别的文件夹名
|
||||
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"] # 异常类别的文件夹名
|
||||
|
||||
# 模型配置
|
||||
HIDDEN_DIM = 768
|
||||
NUM_HEADS = 12
|
||||
NUM_LAYERS = 6
|
||||
DROPOUT = 0.1
|
||||
NUM_CLASSES = 2 # 二分类:正常 vs 异常
|
||||
|
||||
# 训练配置
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
LEARNING_RATE = 1e-4
|
||||
WEIGHT_DECAY = 1e-5
|
||||
NUM_EPOCHS = 50
|
||||
EARLY_STOPPING_PATIENCE = 10
|
||||
BIN
data_preprocessing/__pycache__/data_loader.cpython-312.pyc
Normal file
BIN
data_preprocessing/__pycache__/data_loader.cpython-312.pyc
Normal file
Binary file not shown.
BIN
data_preprocessing/__pycache__/data_loader.cpython-38.pyc
Normal file
BIN
data_preprocessing/__pycache__/data_loader.cpython-38.pyc
Normal file
Binary file not shown.
171
data_preprocessing/data_loader.py
Normal file
171
data_preprocessing/data_loader.py
Normal file
@ -0,0 +1,171 @@
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
class LeukemiaDataset(Dataset):
|
||||
def __init__(self, sample_paths, labels, transform=None):
|
||||
"""
|
||||
初始化白血病数据集
|
||||
|
||||
Args:
|
||||
sample_paths: 样本路径列表,每个元素是(diff_path, wnb_path)元组
|
||||
labels: 对应的标签列表 (0: 正常, 1: 异常)
|
||||
transform: 图像转换操作
|
||||
"""
|
||||
self.sample_paths = sample_paths
|
||||
self.labels = labels
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
diff_path, wnb_path = self.sample_paths[idx]
|
||||
|
||||
# 从路径中提取样本ID
|
||||
sample_id = os.path.basename(diff_path).replace('Diff.png', '')
|
||||
|
||||
# 加载DIFF散点图
|
||||
diff_img = Image.open(diff_path).convert('RGB')
|
||||
|
||||
# 加载WNB散点图
|
||||
wnb_img = Image.open(wnb_path).convert('RGB')
|
||||
|
||||
if self.transform:
|
||||
diff_img = self.transform(diff_img)
|
||||
wnb_img = self.transform(wnb_img)
|
||||
|
||||
label = self.labels[idx]
|
||||
|
||||
return {
|
||||
'id': sample_id,
|
||||
'diff_img': diff_img,
|
||||
'wnb_img': wnb_img,
|
||||
'label': torch.tensor(label, dtype=torch.long)
|
||||
}
|
||||
|
||||
|
||||
def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
|
||||
max_samples_per_class=500, normal_class="000", abnormal_classes=None):
|
||||
"""
|
||||
加载和预处理数据
|
||||
|
||||
Args:
|
||||
data_root: 数据根目录,包含000,001等子文件夹
|
||||
img_size: 调整图像大小
|
||||
batch_size: 批量大小
|
||||
num_workers: 数据加载的工作线程数
|
||||
train_ratio: 训练集比例
|
||||
max_samples_per_class: 每个类别最多读取的样本数
|
||||
normal_class: 正常类别的文件夹名
|
||||
abnormal_classes: 异常类别的文件夹名列表
|
||||
|
||||
Returns:
|
||||
train_loader: 训练数据加载器
|
||||
val_loader: 验证数据加载器
|
||||
"""
|
||||
if abnormal_classes is None:
|
||||
abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]
|
||||
|
||||
# 图像转换
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((img_size, img_size)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# 收集正常样本数据
|
||||
normal_samples = []
|
||||
normal_folder = os.path.join(data_root, normal_class)
|
||||
diff_files = glob.glob(os.path.join(normal_folder, "*Diff.png"))
|
||||
|
||||
# 限制正常样本数量
|
||||
if len(diff_files) > max_samples_per_class:
|
||||
random.shuffle(diff_files)
|
||||
diff_files = diff_files[:max_samples_per_class]
|
||||
|
||||
for diff_file in diff_files:
|
||||
sample_id = os.path.basename(diff_file).replace('Diff.png', '')
|
||||
wnb_file = os.path.join(normal_folder, f"{sample_id}Wnb.png")
|
||||
|
||||
# 确保WNB文件存在
|
||||
if os.path.exists(wnb_file):
|
||||
normal_samples.append((diff_file, wnb_file))
|
||||
|
||||
# 收集异常样本数据
|
||||
abnormal_samples = []
|
||||
|
||||
# 计算每个异常类别分配的样本数
|
||||
# 如果max_samples_per_class=500,总共取500个异常样本,平均分配给各个异常类别
|
||||
samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)
|
||||
|
||||
for abnormal_class in abnormal_classes:
|
||||
abnormal_folder = os.path.join(data_root, abnormal_class)
|
||||
diff_files = glob.glob(os.path.join(abnormal_folder, "*Diff.png"))
|
||||
|
||||
# 限制每个异常类别的样本数量
|
||||
if len(diff_files) > samples_per_abnormal_class:
|
||||
random.shuffle(diff_files)
|
||||
diff_files = diff_files[:samples_per_abnormal_class]
|
||||
|
||||
for diff_file in diff_files:
|
||||
sample_id = os.path.basename(diff_file).replace('Diff.png', '')
|
||||
wnb_file = os.path.join(abnormal_folder, f"{sample_id}Wnb.png")
|
||||
|
||||
# 确保WNB文件存在
|
||||
if os.path.exists(wnb_file):
|
||||
abnormal_samples.append((diff_file, wnb_file))
|
||||
|
||||
# 准备数据集
|
||||
all_samples = normal_samples + abnormal_samples
|
||||
all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)
|
||||
|
||||
print(f"收集到的正常样本数: {len(normal_samples)}")
|
||||
print(f"收集到的异常样本数: {len(abnormal_samples)}")
|
||||
print(f"总样本数: {len(all_samples)}")
|
||||
|
||||
# 划分训练集和验证集
|
||||
indices = np.arange(len(all_samples))
|
||||
train_indices, val_indices = train_test_split(
|
||||
indices,
|
||||
test_size=(1-train_ratio),
|
||||
stratify=all_labels, # 确保训练集和验证集中类别比例一致
|
||||
random_state=42
|
||||
)
|
||||
|
||||
# 准备训练集和验证集
|
||||
train_samples = [all_samples[i] for i in train_indices]
|
||||
train_labels = [all_labels[i] for i in train_indices]
|
||||
|
||||
val_samples = [all_samples[i] for i in val_indices]
|
||||
val_labels = [all_labels[i] for i in val_indices]
|
||||
|
||||
# 创建数据集
|
||||
train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
|
||||
val_dataset = LeukemiaDataset(val_samples, val_labels, transform)
|
||||
|
||||
# 创建数据加载器
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers
|
||||
)
|
||||
|
||||
print(f"训练集样本数: {len(train_dataset)}, 验证集样本数: {len(val_dataset)}")
|
||||
|
||||
return train_loader, val_loader
|
||||
0
data_preprocessing/data_split.py
Normal file
0
data_preprocessing/data_split.py
Normal file
BIN
evaluation/__pycache__/evaluate.cpython-312.pyc
Normal file
BIN
evaluation/__pycache__/evaluate.cpython-312.pyc
Normal file
Binary file not shown.
BIN
evaluation/__pycache__/evaluate.cpython-38.pyc
Normal file
BIN
evaluation/__pycache__/evaluate.cpython-38.pyc
Normal file
Binary file not shown.
185
evaluation/evaluate.py
Normal file
185
evaluation/evaluate.py
Normal file
@ -0,0 +1,185 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
|
||||
from sklearn.manifold import TSNE
|
||||
import seaborn as sns
|
||||
from training.utils import compute_metrics, plot_confusion_matrix, save_results
|
||||
|
||||
def evaluate_model(model, data_loader, device, class_names=None):
|
||||
"""评估模型性能"""
|
||||
model.eval()
|
||||
|
||||
all_labels = []
|
||||
all_preds = []
|
||||
all_probs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in data_loader:
|
||||
# 获取数据和标签
|
||||
diff_imgs = batch['diff_img'].to(device)
|
||||
wnb_imgs = batch['wnb_img'].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(diff_imgs, wnb_imgs)
|
||||
|
||||
# 获取预测和概率
|
||||
probs = torch.softmax(outputs, dim=1)
|
||||
_, preds = torch.max(outputs, 1)
|
||||
|
||||
# 保存结果
|
||||
all_labels.extend(labels.cpu().numpy())
|
||||
all_preds.extend(preds.cpu().numpy())
|
||||
all_probs.extend(probs.cpu().numpy())
|
||||
|
||||
# 计算评估指标
|
||||
metrics = compute_metrics(all_labels, all_preds, all_probs)
|
||||
|
||||
# 绘制混淆矩阵
|
||||
if class_names is None:
|
||||
class_names = ['正常', '异常']
|
||||
plot_confusion_matrix(all_labels, all_preds, class_names)
|
||||
|
||||
# 输出评估结果
|
||||
print(f"准确率: {metrics['accuracy']:.4f}")
|
||||
print(f"精确率: {metrics['precision']:.4f}")
|
||||
print(f"召回率: {metrics['recall']:.4f}")
|
||||
print(f"F1值: {metrics['f1']:.4f}")
|
||||
|
||||
return metrics, all_labels, all_preds, np.array(all_probs)
|
||||
|
||||
|
||||
def plot_roc_curve(all_labels, all_probs, save_path=None):
|
||||
"""绘制ROC曲线"""
|
||||
# 对于二分类问题,取阳性类(异常类别)的概率
|
||||
pos_probs = all_probs[:, 1]
|
||||
|
||||
# 计算ROC曲线
|
||||
fpr, tpr, thresholds = roc_curve(all_labels, pos_probs)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
# 绘制ROC曲线
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
|
||||
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
|
||||
plt.xlim([0.0, 1.0])
|
||||
plt.ylim([0.0, 1.05])
|
||||
plt.xlabel('假阳性率')
|
||||
plt.ylabel('真阳性率')
|
||||
plt.title('受试者工作特征(ROC)曲线')
|
||||
plt.legend(loc='lower right')
|
||||
plt.grid(True)
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
|
||||
plt.show()
|
||||
|
||||
return roc_auc
|
||||
|
||||
|
||||
def extract_and_visualize_features(model, data_loader, device, save_path=None):
|
||||
"""提取特征并使用t-SNE可视化"""
|
||||
model.eval()
|
||||
|
||||
features_dict = {
|
||||
'diff': [],
|
||||
'wnb': []
|
||||
}
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in data_loader:
|
||||
# 获取数据和标签
|
||||
diff_imgs = batch['diff_img'].to(device)
|
||||
wnb_imgs = batch['wnb_img'].to(device)
|
||||
labels = batch['label'].cpu().numpy()
|
||||
|
||||
# 提取特征
|
||||
batch_features = model.extract_features(diff_imgs, wnb_imgs)
|
||||
|
||||
# 保存特征和标签
|
||||
for modality in features_dict:
|
||||
features_dict[modality].extend(batch_features[modality].cpu().numpy())
|
||||
all_labels.extend(labels)
|
||||
|
||||
# 转换为NumPy数组
|
||||
all_labels = np.array(all_labels)
|
||||
|
||||
# 可视化每个模态的特征
|
||||
plt.figure(figsize=(15, 5))
|
||||
|
||||
for i, (modality, features) in enumerate(features_dict.items()):
|
||||
features = np.array(features)
|
||||
|
||||
# 使用t-SNE降维
|
||||
tsne = TSNE(n_components=2, random_state=42)
|
||||
features_tsne = tsne.fit_transform(features)
|
||||
|
||||
# 绘制t-SNE结果
|
||||
plt.subplot(1, 3, i+1)
|
||||
for label in np.unique(all_labels):
|
||||
idx = all_labels == label
|
||||
plt.scatter(features_tsne[idx, 0], features_tsne[idx, 1],
|
||||
label=f"类别 {label}", alpha=0.7)
|
||||
plt.title(f'{modality} 特征 t-SNE 可视化')
|
||||
plt.xlabel('t-SNE 特征 1')
|
||||
plt.ylabel('t-SNE 特征 2')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# 将两种模态的特征连接起来进行可视化
|
||||
combined_features = np.concatenate([features_dict['diff'], features_dict['wnb']], axis=1)
|
||||
tsne = TSNE(n_components=2, random_state=42)
|
||||
combined_tsne = tsne.fit_transform(combined_features)
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
for label in np.unique(all_labels):
|
||||
idx = all_labels == label
|
||||
plt.scatter(combined_tsne[idx, 0], combined_tsne[idx, 1],
|
||||
label=f"类别 {label}", alpha=0.7)
|
||||
plt.title('融合特征 t-SNE 可视化')
|
||||
plt.xlabel('t-SNE 特征 1')
|
||||
plt.ylabel('t-SNE 特征 2')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def compare_models(models_results, model_names, save_path=None):
|
||||
"""比较不同模型的性能"""
|
||||
metrics = ['accuracy', 'precision', 'recall', 'f1']
|
||||
values = []
|
||||
|
||||
for result in models_results:
|
||||
values.append([result[metric] for metric in metrics])
|
||||
|
||||
values = np.array(values)
|
||||
|
||||
# 绘制条形图比较
|
||||
plt.figure(figsize=(10, 6))
|
||||
x = np.arange(len(metrics))
|
||||
width = 0.8 / len(models_results)
|
||||
|
||||
for i, (name, vals) in enumerate(zip(model_names, values)):
|
||||
plt.bar(x + i * width, vals, width, label=name)
|
||||
|
||||
plt.xlabel('评估指标')
|
||||
plt.ylabel('分数')
|
||||
plt.title('不同模型性能比较')
|
||||
plt.xticks(x + width * (len(models_results) - 1) / 2, metrics)
|
||||
plt.ylim(0, 1.0)
|
||||
plt.legend()
|
||||
plt.grid(True, axis='y')
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
|
||||
plt.show()
|
||||
0
evaluation/visualization.py
Normal file
0
evaluation/visualization.py
Normal file
349
main.py
Normal file
349
main.py
Normal file
@ -0,0 +1,349 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import argparse
|
||||
import matplotlib.pyplot as plt
|
||||
from config import *
|
||||
|
||||
from data_preprocessing.data_loader import load_data
|
||||
from models.image_models import VisionTransformer
|
||||
from models.fusion_model import MultiModalFusionModel, SingleModalModel
|
||||
from training.train import train_model, train_single_modal_model
|
||||
from training.utils import plot_training_curves, save_results
|
||||
from evaluation.evaluate import evaluate_model, plot_roc_curve, extract_and_visualize_features, compare_models
|
||||
|
||||
|
||||
def main():
|
||||
# 解析命令行参数
|
||||
parser = argparse.ArgumentParser(description='白血病智能筛查系统')
|
||||
parser.add_argument('--data_root', type=str, default=DATA_ROOT, help='数据根目录')
|
||||
parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='输出目录')
|
||||
parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='批大小')
|
||||
parser.add_argument('--epochs', type=int, default=NUM_EPOCHS, help='训练轮数')
|
||||
parser.add_argument('--lr', type=float, default=LEARNING_RATE, help='学习率')
|
||||
parser.add_argument('--weight_decay', type=float, default=WEIGHT_DECAY, help='权重衰减')
|
||||
parser.add_argument('--mode', type=str, choices=['train', 'evaluate', 'compare'], default='train', help='运行模式')
|
||||
args = parser.parse_args()
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# 加载数据
|
||||
print("加载数据...")
|
||||
train_loader, val_loader = load_data(
|
||||
data_root=args.data_root,
|
||||
img_size=IMG_SIZE,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=NUM_WORKERS,
|
||||
train_ratio=TRAIN_RATIO,
|
||||
max_samples_per_class=MAX_SAMPLES_PER_CLASS,
|
||||
normal_class=NORMAL_CLASS,
|
||||
abnormal_classes=ABNORMAL_CLASSES
|
||||
)
|
||||
print(f"数据加载完成。训练集批次数: {len(train_loader)}, 验证集批次数: {len(val_loader)}")
|
||||
|
||||
# 设置设备
|
||||
device = DEVICE
|
||||
print(f"使用设备: {device}")
|
||||
|
||||
if args.mode == 'train':
|
||||
# 创建多模态模型
|
||||
print("创建多模态融合模型...")
|
||||
multi_modal_model = MultiModalFusionModel(
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=HIDDEN_DIM,
|
||||
depth=NUM_LAYERS,
|
||||
num_heads=NUM_HEADS,
|
||||
dropout=DROPOUT,
|
||||
num_classes=NUM_CLASSES
|
||||
).to(device)
|
||||
|
||||
# 定义损失函数和优化器
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
|
||||
# 训练多模态模型
|
||||
print("开始训练多模态融合模型...")
|
||||
multi_modal_model, multi_modal_history = train_model(
|
||||
model=multi_modal_model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
criterion=criterion,
|
||||
optimizer=optimizer,
|
||||
device=device,
|
||||
num_epochs=args.epochs,
|
||||
save_dir=MODEL_SAVE_DIR,
|
||||
model_name='multi_modal'
|
||||
)
|
||||
|
||||
# 可视化训练历史
|
||||
plot_training_curves(
|
||||
multi_modal_history['train_losses'],
|
||||
multi_modal_history['val_losses'],
|
||||
multi_modal_history['train_accs'],
|
||||
multi_modal_history['val_accs'],
|
||||
save_path=os.path.join(args.output_dir, 'multi_modal_training_curves.png')
|
||||
)
|
||||
|
||||
# 评估多模态模型
|
||||
print("\n评估多模态融合模型...")
|
||||
multi_modal_metrics, labels, preds, probs = evaluate_model(
|
||||
model=multi_modal_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
class_names=['正常', '异常']
|
||||
)
|
||||
|
||||
# 绘制ROC曲线
|
||||
plot_roc_curve(
|
||||
labels,
|
||||
probs,
|
||||
save_path=os.path.join(args.output_dir, 'multi_modal_roc_curve.png')
|
||||
)
|
||||
|
||||
# 提取和可视化特征
|
||||
extract_and_visualize_features(
|
||||
model=multi_modal_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
save_path=os.path.join(args.output_dir, 'feature_visualization.png')
|
||||
)
|
||||
|
||||
# 保存结果
|
||||
save_results(
|
||||
multi_modal_metrics,
|
||||
os.path.join(args.output_dir, 'multi_modal_results.txt')
|
||||
)
|
||||
|
||||
elif args.mode == 'compare':
|
||||
print("创建模型进行比较...")
|
||||
|
||||
# 创建单模态模型 - DIFF
|
||||
diff_model = SingleModalModel(
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=HIDDEN_DIM,
|
||||
depth=NUM_LAYERS,
|
||||
num_heads=NUM_HEADS,
|
||||
dropout=DROPOUT,
|
||||
num_classes=NUM_CLASSES
|
||||
).to(device)
|
||||
|
||||
# 创建单模态模型 - WNB
|
||||
wnb_model = SingleModalModel(
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=HIDDEN_DIM,
|
||||
depth=NUM_LAYERS,
|
||||
num_heads=NUM_HEADS,
|
||||
dropout=DROPOUT,
|
||||
num_classes=NUM_CLASSES
|
||||
).to(device)
|
||||
|
||||
# 创建多模态模型
|
||||
multi_modal_model = MultiModalFusionModel(
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=HIDDEN_DIM,
|
||||
depth=NUM_LAYERS,
|
||||
num_heads=NUM_HEADS,
|
||||
dropout=DROPOUT,
|
||||
num_classes=NUM_CLASSES
|
||||
).to(device)
|
||||
|
||||
# 定义损失函数
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
|
||||
# 训练DIFF单模态模型
|
||||
print("训练DIFF散点图单模态模型...")
|
||||
diff_optimizer = optim.AdamW(diff_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
diff_model, diff_history = train_single_modal_model(
|
||||
model=diff_model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
criterion=criterion,
|
||||
optimizer=diff_optimizer,
|
||||
device=device,
|
||||
num_epochs=args.epochs,
|
||||
save_dir=MODEL_SAVE_DIR,
|
||||
model_name='diff_only',
|
||||
modal_key='diff_img'
|
||||
)
|
||||
|
||||
# 训练WNB单模态模型
|
||||
print("训练WNB散点图单模态模型...")
|
||||
wnb_optimizer = optim.AdamW(wnb_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
wnb_model, wnb_history = train_single_modal_model(
|
||||
model=wnb_model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
criterion=criterion,
|
||||
optimizer=wnb_optimizer,
|
||||
device=device,
|
||||
num_epochs=args.epochs,
|
||||
save_dir=MODEL_SAVE_DIR,
|
||||
model_name='wnb_only',
|
||||
modal_key='wnb_img'
|
||||
)
|
||||
|
||||
# 训练多模态模型
|
||||
print("训练多模态融合模型...")
|
||||
multi_optimizer = optim.AdamW(multi_modal_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
|
||||
multi_modal_model, multi_history = train_model(
|
||||
model=multi_modal_model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
criterion=criterion,
|
||||
optimizer=multi_optimizer,
|
||||
device=device,
|
||||
num_epochs=args.epochs,
|
||||
save_dir=MODEL_SAVE_DIR,
|
||||
model_name='multi_modal'
|
||||
)
|
||||
|
||||
# 评估并比较模型
|
||||
print("\n评估DIFF散点图单模态模型...")
|
||||
diff_metrics, _, _, _ = evaluate_model(
|
||||
model=diff_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
class_names=['正常', '异常']
|
||||
)
|
||||
|
||||
print("\n评估WNB散点图单模态模型...")
|
||||
wnb_metrics, _, _, _ = evaluate_model(
|
||||
model=wnb_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
class_names=['正常', '异常']
|
||||
)
|
||||
|
||||
print("\n评估多模态融合模型...")
|
||||
multi_metrics, _, _, _ = evaluate_model(
|
||||
model=multi_modal_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
class_names=['正常', '异常']
|
||||
)
|
||||
|
||||
# 比较不同模型的性能
|
||||
compare_models(
|
||||
[diff_metrics, wnb_metrics, multi_metrics],
|
||||
['DIFF散点图', 'WNB散点图', '多模态融合'],
|
||||
save_path=os.path.join(args.output_dir, 'model_comparison.png')
|
||||
)
|
||||
|
||||
# 可视化训练曲线
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
plt.subplot(2, 2, 1)
|
||||
plt.plot(diff_history['train_losses'], label='DIFF训练')
|
||||
plt.plot(wnb_history['train_losses'], label='WNB训练')
|
||||
plt.plot(multi_history['train_losses'], label='多模态训练')
|
||||
plt.title('训练损失')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('损失')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.subplot(2, 2, 2)
|
||||
plt.plot(diff_history['val_losses'], label='DIFF验证')
|
||||
plt.plot(wnb_history['val_losses'], label='WNB验证')
|
||||
plt.plot(multi_history['val_losses'], label='多模态验证')
|
||||
plt.title('验证损失')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('损失')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.subplot(2, 2, 3)
|
||||
plt.plot(diff_history['train_accs'], label='DIFF训练')
|
||||
plt.plot(wnb_history['train_accs'], label='WNB训练')
|
||||
plt.plot(multi_history['train_accs'], label='多模态训练')
|
||||
plt.title('训练准确率')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('准确率')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.subplot(2, 2, 4)
|
||||
plt.plot(diff_history['val_accs'], label='DIFF验证')
|
||||
plt.plot(wnb_history['val_accs'], label='WNB验证')
|
||||
plt.plot(multi_history['val_accs'], label='多模态验证')
|
||||
plt.title('验证准确率')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('准确率')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(args.output_dir, 'all_models_training_curves.png'))
|
||||
plt.show()
|
||||
|
||||
# 保存结果
|
||||
save_results(diff_metrics, os.path.join(args.output_dir, 'diff_model_results.txt'))
|
||||
save_results(wnb_metrics, os.path.join(args.output_dir, 'wnb_model_results.txt'))
|
||||
save_results(multi_metrics, os.path.join(args.output_dir, 'multi_modal_results.txt'))
|
||||
|
||||
elif args.mode == 'evaluate':
|
||||
# 加载预训练的多模态模型
|
||||
print("加载预训练的多模态模型...")
|
||||
model_path = os.path.join(MODEL_SAVE_DIR, 'multi_modal_best.pth')
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
print(f"错误:找不到预训练模型 {model_path}")
|
||||
return
|
||||
|
||||
multi_modal_model = MultiModalFusionModel(
|
||||
img_size=IMG_SIZE,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=HIDDEN_DIM,
|
||||
depth=NUM_LAYERS,
|
||||
num_heads=NUM_HEADS,
|
||||
dropout=DROPOUT,
|
||||
num_classes=NUM_CLASSES
|
||||
).to(device)
|
||||
|
||||
multi_modal_model.load_state_dict(torch.load(model_path))
|
||||
|
||||
# 评估模型
|
||||
print("评估多模态融合模型...")
|
||||
multi_modal_metrics, labels, preds, probs = evaluate_model(
|
||||
model=multi_modal_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
class_names=['正常', '异常']
|
||||
)
|
||||
|
||||
# 绘制ROC曲线
|
||||
plot_roc_curve(
|
||||
labels,
|
||||
probs,
|
||||
save_path=os.path.join(args.output_dir, 'multi_modal_roc_curve.png')
|
||||
)
|
||||
|
||||
# 提取和可视化特征
|
||||
extract_and_visualize_features(
|
||||
model=multi_modal_model,
|
||||
data_loader=val_loader,
|
||||
device=device,
|
||||
save_path=os.path.join(args.output_dir, 'feature_visualization.png')
|
||||
)
|
||||
|
||||
# 保存结果
|
||||
save_results(
|
||||
multi_modal_metrics,
|
||||
os.path.join(args.output_dir, 'multi_modal_results.txt')
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
BIN
models/__pycache__/fusion_model.cpython-312.pyc
Normal file
BIN
models/__pycache__/fusion_model.cpython-312.pyc
Normal file
Binary file not shown.
BIN
models/__pycache__/fusion_model.cpython-38.pyc
Normal file
BIN
models/__pycache__/fusion_model.cpython-38.pyc
Normal file
Binary file not shown.
BIN
models/__pycache__/image_models.cpython-312.pyc
Normal file
BIN
models/__pycache__/image_models.cpython-312.pyc
Normal file
Binary file not shown.
BIN
models/__pycache__/image_models.cpython-38.pyc
Normal file
BIN
models/__pycache__/image_models.cpython-38.pyc
Normal file
Binary file not shown.
116
models/fusion_model.py
Normal file
116
models/fusion_model.py
Normal file
@ -0,0 +1,116 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from models.image_models import VisionTransformer
|
||||
|
||||
class MultiModalFusionModel(nn.Module):
|
||||
"""多模态融合模型,融合DIFF和WNB散点图的特征"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_channels=3,
|
||||
embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
|
||||
super().__init__()
|
||||
|
||||
# DIFF散点图特征提取器
|
||||
self.diff_encoder = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dim,
|
||||
depth=depth,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout
|
||||
)
|
||||
|
||||
# WNB散点图特征提取器
|
||||
self.wnb_encoder = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dim,
|
||||
depth=depth,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout
|
||||
)
|
||||
|
||||
# 特征融合层
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(embed_dim * 2, embed_dim),
|
||||
nn.LayerNorm(embed_dim),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
# 分类头
|
||||
self.classifier = nn.Linear(embed_dim, num_classes)
|
||||
|
||||
def forward(self, diff_img, wnb_img):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
diff_img: DIFF散点图 [B, C, H, W]
|
||||
wnb_img: WNB散点图 [B, C, H, W]
|
||||
|
||||
Returns:
|
||||
logits: 分类logits [B, num_classes]
|
||||
"""
|
||||
# 提取特征
|
||||
diff_features = self.diff_encoder(diff_img) # [B, E]
|
||||
wnb_features = self.wnb_encoder(wnb_img) # [B, E]
|
||||
|
||||
# 特征融合
|
||||
combined_features = torch.cat([diff_features, wnb_features], dim=1) # [B, 2*E]
|
||||
fused_features = self.fusion(combined_features) # [B, E]
|
||||
|
||||
# 分类
|
||||
logits = self.classifier(fused_features) # [B, num_classes]
|
||||
|
||||
return logits
|
||||
|
||||
def extract_features(self, diff_img, wnb_img):
|
||||
"""提取各个模态的特征,用于分析"""
|
||||
diff_features = self.diff_encoder(diff_img)
|
||||
wnb_features = self.wnb_encoder(wnb_img)
|
||||
|
||||
return {
|
||||
'diff': diff_features,
|
||||
'wnb': wnb_features
|
||||
}
|
||||
|
||||
|
||||
class SingleModalModel(nn.Module):
|
||||
"""单模态模型,用于对比实验"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_channels=3,
|
||||
embed_dim=768, depth=6, num_heads=12, dropout=0.1, num_classes=2):
|
||||
super().__init__()
|
||||
|
||||
# 图像特征提取器
|
||||
self.encoder = VisionTransformer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=embed_dim,
|
||||
depth=depth,
|
||||
num_heads=num_heads,
|
||||
dropout=dropout
|
||||
)
|
||||
|
||||
# 分类头
|
||||
self.classifier = nn.Linear(embed_dim, num_classes)
|
||||
|
||||
def forward(self, img):
|
||||
"""
|
||||
前向传播
|
||||
|
||||
Args:
|
||||
img: 输入图像 [B, C, H, W]
|
||||
|
||||
Returns:
|
||||
logits: 分类logits [B, num_classes]
|
||||
"""
|
||||
# 提取特征
|
||||
features = self.encoder(img) # [B, E]
|
||||
|
||||
# 分类
|
||||
logits = self.classifier(features) # [B, num_classes]
|
||||
|
||||
return logits
|
||||
78
models/image_models.py
Normal file
78
models/image_models.py
Normal file
@ -0,0 +1,78 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
||||
import torch.nn.functional as F
|
||||
import torchvision.models as models
|
||||
|
||||
class PatchEmbedding(nn.Module):
|
||||
"""将图像分割为patch并进行embedding"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.n_patches = (img_size // patch_size) ** 2
|
||||
|
||||
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, C, H, W]
|
||||
batch_size = x.shape[0]
|
||||
x = self.proj(x) # [B, E, H/P, W/P]
|
||||
x = x.flatten(2) # [B, E, (H/P)*(W/P)]
|
||||
x = x.transpose(1, 2) # [B, (H/P)*(W/P), E]
|
||||
return x
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""基于Transformer的图像特征提取模型"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_channels=3,
|
||||
embed_dim=768, depth=6, num_heads=12, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
# Patch Embedding
|
||||
self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
|
||||
self.n_patches = self.patch_embed.n_patches
|
||||
|
||||
# Position Embedding
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, embed_dim))
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
|
||||
# Transformer Encoder
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model=embed_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=embed_dim * 4,
|
||||
dropout=dropout,
|
||||
activation='gelu',
|
||||
batch_first=True
|
||||
)
|
||||
self.transformer = TransformerEncoder(encoder_layer, num_layers=depth)
|
||||
|
||||
# 层归一化
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
# 初始化
|
||||
nn.init.trunc_normal_(self.pos_embed, std=0.02)
|
||||
nn.init.trunc_normal_(self.cls_token, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, C, H, W]
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# Patch Embedding: [B, N, E]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
# 添加CLS token
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1) # [B, 1, E]
|
||||
x = torch.cat([cls_token, x], dim=1) # [B, N+1, E]
|
||||
|
||||
# 添加Position Embedding
|
||||
x = x + self.pos_embed
|
||||
|
||||
# Transformer Encoder
|
||||
x = self.transformer(x)
|
||||
|
||||
# 提取CLS token作为整个图像的特征
|
||||
x = x[:, 0] # [B, E]
|
||||
|
||||
return x
|
||||
BIN
training/__pycache__/train.cpython-312.pyc
Normal file
BIN
training/__pycache__/train.cpython-312.pyc
Normal file
Binary file not shown.
BIN
training/__pycache__/train.cpython-38.pyc
Normal file
BIN
training/__pycache__/train.cpython-38.pyc
Normal file
Binary file not shown.
BIN
training/__pycache__/utils.cpython-312.pyc
Normal file
BIN
training/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
training/__pycache__/utils.cpython-38.pyc
Normal file
BIN
training/__pycache__/utils.cpython-38.pyc
Normal file
Binary file not shown.
312
training/train.py
Normal file
312
training/train.py
Normal file
@ -0,0 +1,312 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import numpy as np
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
from training.utils import AverageMeter, EarlyStopping, plot_training_curves, compute_metrics
|
||||
|
||||
def train_epoch(model, train_loader, criterion, optimizer, device):
|
||||
"""训练一个epoch"""
|
||||
model.train()
|
||||
losses = AverageMeter()
|
||||
acc = AverageMeter()
|
||||
|
||||
# 进度条
|
||||
pbar = tqdm(train_loader, desc='训练')
|
||||
|
||||
for batch in pbar:
|
||||
# 获取数据和标签
|
||||
diff_imgs = batch['diff_img'].to(device)
|
||||
wnb_imgs = batch['wnb_img'].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(diff_imgs, wnb_imgs)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# 反向传播和优化
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# 计算准确率
|
||||
_, preds = torch.max(outputs, 1)
|
||||
batch_acc = (preds == labels).float().mean()
|
||||
|
||||
# 更新统计
|
||||
losses.update(loss.item(), labels.size(0))
|
||||
acc.update(batch_acc.item(), labels.size(0))
|
||||
|
||||
# 更新进度条
|
||||
pbar.set_postfix({
|
||||
'loss': losses.avg,
|
||||
'acc': acc.avg
|
||||
})
|
||||
|
||||
return losses.avg, acc.avg
|
||||
|
||||
|
||||
def validate(model, val_loader, criterion, device):
|
||||
"""验证模型"""
|
||||
model.eval()
|
||||
losses = AverageMeter()
|
||||
acc = AverageMeter()
|
||||
|
||||
all_labels = []
|
||||
all_preds = []
|
||||
all_probs = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(val_loader, desc='验证'):
|
||||
# 获取数据和标签
|
||||
diff_imgs = batch['diff_img'].to(device)
|
||||
wnb_imgs = batch['wnb_img'].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(diff_imgs, wnb_imgs)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# 计算准确率
|
||||
probs = torch.softmax(outputs, dim=1)
|
||||
_, preds = torch.max(outputs, 1)
|
||||
batch_acc = (preds == labels).float().mean()
|
||||
|
||||
# 更新统计
|
||||
losses.update(loss.item(), labels.size(0))
|
||||
acc.update(batch_acc.item(), labels.size(0))
|
||||
|
||||
# 保存预测结果用于计算指标
|
||||
all_labels.extend(labels.cpu().numpy())
|
||||
all_preds.extend(preds.cpu().numpy())
|
||||
all_probs.extend(probs.cpu().numpy())
|
||||
|
||||
# 计算其他评估指标
|
||||
metrics = compute_metrics(all_labels, all_preds, all_probs)
|
||||
metrics['loss'] = losses.avg
|
||||
metrics['accuracy'] = acc.avg
|
||||
|
||||
return losses.avg, acc.avg, metrics
|
||||
|
||||
|
||||
def train_model(model, train_loader, val_loader, criterion, optimizer, device,
|
||||
num_epochs=50, save_dir='./models', model_name='model'):
|
||||
"""训练模型"""
|
||||
# 保存最佳模型的路径
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
|
||||
|
||||
# 初始化早停
|
||||
early_stopping = EarlyStopping(patience=10, path=best_model_path)
|
||||
|
||||
# 跟踪训练历史
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
train_accs = []
|
||||
val_accs = []
|
||||
best_val_acc = 0.0
|
||||
|
||||
# 开始训练
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f'\nEpoch {epoch+1}/{num_epochs}')
|
||||
print('-' * 20)
|
||||
|
||||
# 训练
|
||||
train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
|
||||
train_losses.append(train_loss)
|
||||
train_accs.append(train_acc)
|
||||
|
||||
# 验证
|
||||
val_loss, val_acc, val_metrics = validate(model, val_loader, criterion, device)
|
||||
val_losses.append(val_loss)
|
||||
val_accs.append(val_acc)
|
||||
|
||||
# 打印当前epoch的结果
|
||||
print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
|
||||
print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
|
||||
print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')
|
||||
|
||||
# 检查是否为最佳验证准确率
|
||||
if val_acc > best_val_acc:
|
||||
best_val_acc = val_acc
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'val_acc': val_acc,
|
||||
'val_metrics': val_metrics
|
||||
}, os.path.join(save_dir, f'{model_name}_best_acc.pth'))
|
||||
print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')
|
||||
|
||||
# 早停检查
|
||||
early_stopping(val_loss, model)
|
||||
if early_stopping.early_stop:
|
||||
print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
|
||||
break
|
||||
|
||||
# 计算总训练时间
|
||||
total_time = time.time() - start_time
|
||||
print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')
|
||||
|
||||
# 加载最佳模型
|
||||
model.load_state_dict(torch.load(best_model_path))
|
||||
|
||||
return model, {
|
||||
'train_losses': train_losses,
|
||||
'val_losses': val_losses,
|
||||
'train_accs': train_accs,
|
||||
'val_accs': val_accs,
|
||||
'best_val_acc': best_val_acc,
|
||||
'total_time': total_time
|
||||
}
|
||||
|
||||
|
||||
def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device,
|
||||
num_epochs=50, save_dir='./models', model_name='single_modal', modal_key='diff_img'):
|
||||
"""训练单模态模型"""
|
||||
# 保存最佳模型的路径
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
|
||||
|
||||
# 初始化早停
|
||||
early_stopping = EarlyStopping(patience=10, path=best_model_path)
|
||||
|
||||
# 跟踪训练历史
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
train_accs = []
|
||||
val_accs = []
|
||||
best_val_acc = 0.0
|
||||
|
||||
# 开始训练
|
||||
start_time = time.time()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
print(f'\nEpoch {epoch+1}/{num_epochs}')
|
||||
print('-' * 20)
|
||||
|
||||
# 训练
|
||||
model.train()
|
||||
train_loss = AverageMeter()
|
||||
train_acc = AverageMeter()
|
||||
|
||||
pbar = tqdm(train_loader, desc='训练')
|
||||
for batch in pbar:
|
||||
# 获取数据和标签
|
||||
imgs = batch[modal_key].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(imgs)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# 反向传播和优化
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# 计算准确率
|
||||
_, preds = torch.max(outputs, 1)
|
||||
batch_acc = (preds == labels).float().mean()
|
||||
|
||||
# 更新统计
|
||||
train_loss.update(loss.item(), labels.size(0))
|
||||
train_acc.update(batch_acc.item(), labels.size(0))
|
||||
|
||||
# 更新进度条
|
||||
pbar.set_postfix({
|
||||
'loss': train_loss.avg,
|
||||
'acc': train_acc.avg
|
||||
})
|
||||
|
||||
train_losses.append(train_loss.avg)
|
||||
train_accs.append(train_acc.avg)
|
||||
|
||||
# 验证
|
||||
model.eval()
|
||||
val_loss = AverageMeter()
|
||||
val_acc = AverageMeter()
|
||||
|
||||
all_labels = []
|
||||
all_preds = []
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(val_loader, desc='验证'):
|
||||
# 获取数据和标签
|
||||
imgs = batch[modal_key].to(device)
|
||||
labels = batch['label'].to(device)
|
||||
|
||||
# 前向传播
|
||||
outputs = model(imgs)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(outputs, labels)
|
||||
|
||||
# 计算准确率
|
||||
_, preds = torch.max(outputs, 1)
|
||||
batch_acc = (preds == labels).float().mean()
|
||||
|
||||
# 更新统计
|
||||
val_loss.update(loss.item(), labels.size(0))
|
||||
val_acc.update(batch_acc.item(), labels.size(0))
|
||||
|
||||
# 保存预测结果用于计算指标
|
||||
all_labels.extend(labels.cpu().numpy())
|
||||
all_preds.extend(preds.cpu().numpy())
|
||||
|
||||
val_losses.append(val_loss.avg)
|
||||
val_accs.append(val_acc.avg)
|
||||
|
||||
# 计算其他评估指标
|
||||
val_metrics = compute_metrics(all_labels, all_preds)
|
||||
|
||||
# 打印当前epoch的结果
|
||||
print(f'训练损失: {train_loss.avg:.4f} 训练准确率: {train_acc.avg:.4f}')
|
||||
print(f'验证损失: {val_loss.avg:.4f} 验证准确率: {val_acc.avg:.4f}')
|
||||
print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')
|
||||
|
||||
# 检查是否为最佳验证准确率
|
||||
if val_acc.avg > best_val_acc:
|
||||
best_val_acc = val_acc.avg
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': optimizer.state_dict(),
|
||||
'val_acc': val_acc.avg,
|
||||
'val_metrics': val_metrics
|
||||
}, os.path.join(save_dir, f'{model_name}_best_acc.pth'))
|
||||
print(f'保存新的最佳模型,验证准确率: {val_acc.avg:.4f}')
|
||||
|
||||
# 早停检查
|
||||
early_stopping(val_loss.avg, model)
|
||||
if early_stopping.early_stop:
|
||||
print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
|
||||
break
|
||||
|
||||
# 计算总训练时间
|
||||
total_time = time.time() - start_time
|
||||
print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')
|
||||
|
||||
# 加载最佳模型
|
||||
model.load_state_dict(torch.load(best_model_path))
|
||||
|
||||
return model, {
|
||||
'train_losses': train_losses,
|
||||
'val_losses': val_losses,
|
||||
'train_accs': train_accs,
|
||||
'val_accs': val_accs,
|
||||
'best_val_acc': best_val_acc,
|
||||
'total_time': total_time
|
||||
}
|
||||
126
training/utils.py
Normal file
126
training/utils.py
Normal file
@ -0,0 +1,126 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||||
import seaborn as sns
|
||||
import os
|
||||
|
||||
class AverageMeter:
|
||||
"""跟踪平均值和当前值"""
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class EarlyStopping:
|
||||
"""提前停止训练,避免过拟合"""
|
||||
def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
|
||||
self.patience = patience
|
||||
self.delta = delta
|
||||
self.path = path
|
||||
self.counter = 0
|
||||
self.best_score = None
|
||||
self.early_stop = False
|
||||
|
||||
def __call__(self, val_loss, model):
|
||||
score = -val_loss
|
||||
|
||||
if self.best_score is None:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
elif score < self.best_score + self.delta:
|
||||
self.counter += 1
|
||||
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
|
||||
if self.counter >= self.patience:
|
||||
self.early_stop = True
|
||||
else:
|
||||
self.best_score = score
|
||||
self.save_checkpoint(val_loss, model)
|
||||
self.counter = 0
|
||||
|
||||
def save_checkpoint(self, val_loss, model):
|
||||
torch.save(model.state_dict(), self.path)
|
||||
print(f'验证损失降低 ({self.best_score:.6f} --> {-val_loss:.6f}). 保存模型...')
|
||||
|
||||
|
||||
def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
|
||||
"""绘制训练和验证的损失与准确率曲线"""
|
||||
plt.figure(figsize=(12, 5))
|
||||
|
||||
plt.subplot(1, 2, 1)
|
||||
plt.plot(train_losses, label='训练损失')
|
||||
plt.plot(val_losses, label='验证损失')
|
||||
plt.title('损失曲线')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('损失')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.plot(train_accs, label='训练准确率')
|
||||
plt.plot(val_accs, label='验证准确率')
|
||||
plt.title('准确率曲线')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('准确率')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
|
||||
"""绘制混淆矩阵"""
|
||||
cm = confusion_matrix(y_true, y_pred)
|
||||
|
||||
plt.figure(figsize=(8, 6))
|
||||
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
|
||||
xticklabels=class_names if class_names else "auto",
|
||||
yticklabels=class_names if class_names else "auto")
|
||||
plt.title('混淆矩阵')
|
||||
plt.xlabel('预测标签')
|
||||
plt.ylabel('真实标签')
|
||||
|
||||
if save_path:
|
||||
plt.savefig(save_path)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def compute_metrics(y_true, y_pred, y_proba=None):
|
||||
"""计算各种评估指标"""
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
|
||||
recall = recall_score(y_true, y_pred, average='binary')
|
||||
f1 = f1_score(y_true, y_pred, average='binary')
|
||||
|
||||
metrics = {
|
||||
'accuracy': accuracy,
|
||||
'precision': precision,
|
||||
'recall': recall,
|
||||
'f1': f1
|
||||
}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def save_results(results, filename):
|
||||
"""保存结果到文本文件"""
|
||||
with open(filename, 'w') as f:
|
||||
for key, value in results.items():
|
||||
f.write(f"{key}: {value}\n")
|
||||
Loading…
Reference in New Issue
Block a user