312 lines
10 KiB
Python
312 lines
10 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
import numpy as np
|
|
import os
|
|
from tqdm import tqdm
|
|
import time
|
|
from training.utils import AverageMeter, EarlyStopping, plot_training_curves, compute_metrics
|
|
|
|
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` for a two-input model.

    Each batch is indexed with the keys ``'diff_img'``, ``'wnb_img'`` and
    ``'label'``; the two image tensors are passed to ``model`` together.

    Args:
        model: network called as ``model(diff_imgs, wnb_imgs)``.
        train_loader: iterable yielding batch mappings.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: target device for the batch tensors.

    Returns:
        ``(mean_loss, mean_accuracy)``: running averages kept by
        ``AverageMeter``, each batch weighted by its sample count.
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress = tqdm(train_loader, desc='训练')
    for batch in progress:
        diff_imgs, wnb_imgs, labels = (
            batch[key].to(device) for key in ('diff_img', 'wnb_img', 'label')
        )

        logits = model(diff_imgs, wnb_imgs)
        loss = criterion(logits, labels)

        # Clear stale gradients, backpropagate, then apply the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Index of the maximum along dim 1 is taken as the predicted class.
        predicted = logits.max(dim=1).indices

        n_samples = labels.size(0)
        loss_meter.update(loss.item(), n_samples)
        acc_meter.update((predicted == labels).float().mean().item(), n_samples)

        progress.set_postfix({'loss': loss_meter.avg, 'acc': acc_meter.avg})

    return loss_meter.avg, acc_meter.avg
|
|
|
|
|
|
def validate(model, val_loader, criterion, device):
    """Evaluate a two-input model on ``val_loader`` with gradients disabled.

    Args:
        model: network called as ``model(diff_imgs, wnb_imgs)``.
        val_loader: iterable yielding batch mappings with ``'diff_img'``,
            ``'wnb_img'`` and ``'label'`` entries.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: target device for the batch tensors.

    Returns:
        ``(mean_loss, mean_accuracy, metrics)``. ``metrics`` is the dict
        produced by ``compute_metrics(labels, preds, probs)`` with ``'loss'``
        and ``'accuracy'`` entries added.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    labels_seen, preds_seen, probs_seen = [], [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='验证'):
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].to(device)

            logits = model(diff_imgs, wnb_imgs)
            loss = criterion(logits, labels)

            # Softmax over dim 1 gives per-class probabilities for the metrics.
            class_probs = torch.softmax(logits, dim=1)
            predicted = logits.max(dim=1).indices

            n_samples = labels.size(0)
            loss_meter.update(loss.item(), n_samples)
            acc_meter.update((predicted == labels).float().mean().item(), n_samples)

            # Move to host memory so the whole epoch can be scored at once.
            labels_seen.extend(labels.cpu().numpy())
            preds_seen.extend(predicted.cpu().numpy())
            probs_seen.extend(class_probs.cpu().numpy())

    metrics = compute_metrics(labels_seen, preds_seen, probs_seen)
    metrics['loss'] = loss_meter.avg
    metrics['accuracy'] = acc_meter.avg
    return loss_meter.avg, acc_meter.avg, metrics
|
|
|
|
|
|
def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                num_epochs=50, save_dir='./models', model_name='model',
                patience=10):
    """Train a two-input model with per-epoch validation and early stopping.

    Each epoch calls ``train_epoch`` and then ``validate``. Two checkpoints
    may be written to ``save_dir``:

    * ``{model_name}_best_acc.pth``: a dict with ``epoch``,
      ``model_state_dict``, ``optimizer_state_dict``, ``val_acc`` and
      ``val_metrics``. It is written whenever validation accuracy strictly
      improves.
    * ``{model_name}_best.pth``: the path handed to ``EarlyStopping``, which
      is driven by validation loss. If that file exists after training, it is
      loaded back into ``model`` with ``load_state_dict``.

    Args:
        model: network called as ``model(diff_imgs, wnb_imgs)``.
        train_loader: training batches (see ``train_epoch``).
        val_loader: validation batches (see ``validate``).
        criterion: loss function.
        optimizer: optimizer whose state is saved in the accuracy checkpoint.
        device: device for batches; also used as ``map_location`` when
            reloading the best checkpoint.
        num_epochs: maximum number of epochs.
        save_dir: directory for checkpoints; created if missing.
        model_name: filename prefix for the checkpoints.
        patience: forwarded to ``EarlyStopping`` (default 10, the previously
            hard-coded value).

    Returns:
        ``(model, history)``. ``history`` holds per-epoch ``train_losses``,
        ``val_losses``, ``train_accs`` and ``val_accs``, plus ``best_val_acc``
        and ``total_time`` in seconds.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    # NOTE(review): EarlyStopping is assumed to save model.state_dict() to
    # `path` on improvement, since the final reload uses load_state_dict.
    # Confirm in training.utils.
    early_stopping = EarlyStopping(patience=patience, path=best_model_path)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, val_metrics = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Accuracy-based checkpoint, kept separately from the loss-based one
        # managed by EarlyStopping.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Reload the EarlyStopping checkpoint. Guard against it never having been
    # written (e.g. num_epochs == 0) instead of crashing with
    # FileNotFoundError. map_location lets a checkpoint saved on another
    # device load onto `device`.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print(f'未找到最佳模型文件 {best_model_path},返回当前模型权重')

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }
|
|
|
|
|
|
def _train_single_modal_epoch(model, train_loader, criterion, optimizer, device, modal_key):
    """Run one training pass feeding only ``batch[modal_key]`` to ``model``.

    Returns:
        ``(mean_loss, mean_accuracy)`` weighted by batch size.
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    pbar = tqdm(train_loader, desc='训练')
    for batch in pbar:
        imgs = batch[modal_key].to(device)
        labels = batch['label'].to(device)

        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Index of the maximum along dim 1 is taken as the predicted class.
        _, preds = torch.max(outputs, 1)
        batch_acc = (preds == labels).float().mean()

        loss_meter.update(loss.item(), labels.size(0))
        acc_meter.update(batch_acc.item(), labels.size(0))
        pbar.set_postfix({'loss': loss_meter.avg, 'acc': acc_meter.avg})

    return loss_meter.avg, acc_meter.avg


def _validate_single_modal(model, val_loader, criterion, device, modal_key):
    """Evaluate ``model`` on ``batch[modal_key]`` only, with gradients disabled.

    Returns:
        ``(mean_loss, mean_accuracy, metrics)``. ``metrics`` comes from
        ``compute_metrics(labels, preds)`` (no probabilities are passed) with
        ``'loss'`` and ``'accuracy'`` added, matching ``validate``.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='验证'):
            imgs = batch[modal_key].to(device)
            labels = batch['label'].to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            batch_acc = (preds == labels).float().mean()

            loss_meter.update(loss.item(), labels.size(0))
            acc_meter.update(batch_acc.item(), labels.size(0))

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    metrics = compute_metrics(all_labels, all_preds)
    metrics['loss'] = loss_meter.avg
    metrics['accuracy'] = acc_meter.avg
    return loss_meter.avg, acc_meter.avg, metrics


def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device,
                             num_epochs=50, save_dir='./models', model_name='single_modal',
                             modal_key='diff_img', patience=10):
    """Train a single-input model on one modality with validation and early stopping.

    This mirrors ``train_model``, but every batch contributes only
    ``batch[modal_key]`` and the model is called as ``model(imgs)``.
    Checkpoints written to ``save_dir``:

    * ``{model_name}_best_acc.pth``: epoch, model/optimizer state,
      ``val_acc`` and ``val_metrics``, written when validation accuracy
      strictly improves.
    * ``{model_name}_best.pth``: the path handed to ``EarlyStopping``
      (driven by validation loss). It is reloaded into ``model`` after
      training if it exists.

    Args:
        model: single-input network.
        train_loader: training batches with ``modal_key`` and ``'label'`` keys.
        val_loader: validation batches with the same keys.
        criterion: loss function.
        optimizer: optimizer whose state is saved in the accuracy checkpoint.
        device: device for batches; also ``map_location`` for the reload.
        num_epochs: maximum number of epochs.
        save_dir: checkpoint directory; created if missing.
        model_name: filename prefix for checkpoints.
        modal_key: batch key of the input modality (default ``'diff_img'``).
        patience: forwarded to ``EarlyStopping`` (default 10, the previously
            hard-coded value).

    Returns:
        ``(model, history)``, with the same history keys as ``train_model``.
    """
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    # NOTE(review): EarlyStopping is assumed to save model.state_dict() to
    # `path` on improvement. Confirm in training.utils.
    early_stopping = EarlyStopping(patience=patience, path=best_model_path)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        train_loss, train_acc = _train_single_modal_epoch(
            model, train_loader, criterion, optimizer, device, modal_key)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, val_metrics = _validate_single_modal(
            model, val_loader, criterion, device, modal_key)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Accuracy-based checkpoint, separate from EarlyStopping's loss-based one.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Guard against the checkpoint never having been written instead of
    # raising FileNotFoundError. map_location allows cross-device reloads.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print(f'未找到最佳模型文件 {best_model_path},返回当前模型权重')

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }