import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

from training.utils import (
    AverageMeter,
    EarlyStopping,
    plot_training_curves,
    compute_metrics,
)


def _multimodal_inputs(batch, device):
    """Return the ``diff_img`` and ``wnb_img`` entries of ``batch`` on ``device``."""
    return batch['diff_img'].to(device), batch['wnb_img'].to(device)


def _single_modal_inputs(modal_key):
    """Build an input selector that moves only ``batch[modal_key]`` to the device."""
    def select(batch, device):
        return (batch[modal_key].to(device),)

    return select


def _run_train_epoch(model, loader, criterion, optimizer, device,
                     select_inputs):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(*inputs)``.
        loader: Iterable of dict batches; each must contain ``'label'``.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        device: Target device for inputs and labels.
        select_inputs: Callable ``(batch, device) -> tuple`` producing the
            positional inputs for the model.

    Returns:
        ``(loss_avg, acc_avg)`` as reported by the two ``AverageMeter``
        instances, each updated with the batch size as weight.
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress = tqdm(loader, desc='训练')
    for batch in progress:
        inputs = select_inputs(batch, device)
        labels = batch['label'].to(device)

        outputs = model(*inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predicted class is the arg-max over dim 1 of the model output.
        _, preds = torch.max(outputs, 1)
        batch_acc = (preds == labels).float().mean()

        batch_size = labels.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc_meter.update(batch_acc.item(), batch_size)

        progress.set_postfix({
            'loss': loss_meter.avg,
            'acc': acc_meter.avg
        })

    return loss_meter.avg, acc_meter.avg


def _run_eval_epoch(model, loader, criterion, device, select_inputs,
                    collect_probs):
    """Evaluate ``model`` over ``loader`` with gradients disabled.

    Args:
        model: Module called as ``model(*inputs)``.
        loader: Iterable of dict batches; each must contain ``'label'``.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        device: Target device for inputs and labels.
        select_inputs: Callable ``(batch, device) -> tuple`` producing the
            positional inputs for the model.
        collect_probs: When true, also gather ``softmax(outputs, dim=1)``
            for every sample.

    Returns:
        ``(loss_avg, acc_avg, all_labels, all_preds, all_probs)``; the three
        lists hold per-sample NumPy values, and ``all_probs`` stays empty
        when ``collect_probs`` is false.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='验证'):
            inputs = select_inputs(batch, device)
            labels = batch['label'].to(device)

            outputs = model(*inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            batch_acc = (preds == labels).float().mean()

            batch_size = labels.size(0)
            loss_meter.update(loss.item(), batch_size)
            acc_meter.update(batch_acc.item(), batch_size)

            # Keep per-sample results on the CPU for metric computation.
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            if collect_probs:
                probs = torch.softmax(outputs, dim=1)
                all_probs.extend(probs.cpu().numpy())

    return loss_meter.avg, acc_meter.avg, all_labels, all_preds, all_probs


def _fit(model, optimizer, run_train_epoch, run_validation, num_epochs,
         save_dir, model_name):
    """Shared epoch loop: train, validate, checkpoint, early-stop, reload.

    Args:
        model: Module being trained; its ``state_dict`` is checkpointed.
        optimizer: Optimizer whose ``state_dict`` is stored in the
            best-accuracy checkpoint.
        run_train_epoch: Zero-argument callable returning
            ``(train_loss, train_acc)`` for one epoch.
        run_validation: Zero-argument callable returning
            ``(val_loss, val_acc, val_metrics)``; ``val_metrics`` must expose
            ``'precision'``, ``'recall'`` and ``'f1'``.
        num_epochs: Maximum number of epochs.
        save_dir: Directory for checkpoints (created if missing).
        model_name: Filename prefix for the checkpoints.

    Returns:
        ``(model, history)`` where ``history`` holds the per-epoch loss and
        accuracy lists, ``best_val_acc`` and ``total_time`` in seconds.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    # EarlyStopping is a project helper; presumably it writes the model
    # weights to `path` whenever validation loss improves -- confirm in
    # training.utils. The reload after the loop relies on that file.
    early_stopping = EarlyStopping(patience=10, path=best_model_path)

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        train_loss, train_acc = run_train_epoch()
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, val_metrics = run_validation()
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Separate checkpoint keyed on validation accuracy (the early-stopping
        # checkpoint above is keyed on validation loss).
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the early-stopping checkpoint. If none was written (e.g.
    # num_epochs == 0) keep the current weights instead of crashing.
    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path))
    else:
        print(f'未找到检查点 {best_model_path}, 保留当前模型权重')

    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }
    return model, history


def train_epoch(model, train_loader, criterion, optimizer, device):
    """Train the two-input model for one epoch.

    Each batch must provide ``'diff_img'``, ``'wnb_img'`` and ``'label'``;
    the model is called as ``model(diff_imgs, wnb_imgs)``.

    Returns:
        ``(average_loss, average_accuracy)`` over the epoch.
    """
    return _run_train_epoch(model, train_loader, criterion, optimizer,
                            device, _multimodal_inputs)


def validate(model, val_loader, criterion, device):
    """Evaluate the two-input model on ``val_loader``.

    Returns:
        ``(average_loss, average_accuracy, metrics)`` where ``metrics`` is
        the result of ``compute_metrics(labels, preds, probs)`` extended with
        ``'loss'`` and ``'accuracy'`` entries.
    """
    val_loss, val_acc, all_labels, all_preds, all_probs = _run_eval_epoch(
        model, val_loader, criterion, device, _multimodal_inputs,
        collect_probs=True)

    metrics = compute_metrics(all_labels, all_preds, all_probs)
    metrics['loss'] = val_loss
    metrics['accuracy'] = val_acc

    return val_loss, val_acc, metrics


def train_model(model, train_loader, val_loader, criterion, optimizer,
                device, num_epochs=50, save_dir='./models',
                model_name='model'):
    """Train the two-input model with checkpointing and early stopping.

    Args:
        model: Module called as ``model(diff_imgs, wnb_imgs)``.
        train_loader: Training batches (see ``train_epoch``).
        val_loader: Validation batches (see ``validate``).
        criterion: Loss callable.
        optimizer: Optimizer for ``model``.
        device: Device that inputs and labels are moved to.
        num_epochs: Maximum number of epochs.
        save_dir: Directory for checkpoints (created if missing).
        model_name: Prefix for ``<name>_best.pth`` / ``<name>_best_acc.pth``.

    Returns:
        ``(model, history)``; ``model`` has the early-stopping checkpoint
        loaded when one exists, and ``history`` contains ``train_losses``,
        ``val_losses``, ``train_accs``, ``val_accs``, ``best_val_acc`` and
        ``total_time``.
    """
    return _fit(
        model,
        optimizer,
        lambda: train_epoch(model, train_loader, criterion, optimizer, device),
        lambda: validate(model, val_loader, criterion, device),
        num_epochs,
        save_dir,
        model_name,
    )


def train_single_modal_model(model, train_loader, val_loader, criterion,
                             optimizer, device, num_epochs=50,
                             save_dir='./models', model_name='single_modal',
                             modal_key='diff_img'):
    """Train a single-input model with checkpointing and early stopping.

    Identical to ``train_model`` except that the model is called as
    ``model(batch[modal_key])`` and validation metrics come from
    ``compute_metrics(labels, preds)`` (no probabilities, and no added
    ``'loss'``/``'accuracy'`` entries).

    Args:
        model: Module called with one image tensor.
        train_loader: Training batches containing ``modal_key`` and ``'label'``.
        val_loader: Validation batches with the same keys.
        criterion: Loss callable.
        optimizer: Optimizer for ``model``.
        device: Device that inputs and labels are moved to.
        num_epochs: Maximum number of epochs.
        save_dir: Directory for checkpoints (created if missing).
        model_name: Prefix for the checkpoint filenames.
        modal_key: Batch key selecting the input modality.

    Returns:
        ``(model, history)`` with the same layout as ``train_model``.
    """
    select_inputs = _single_modal_inputs(modal_key)

    def run_train_epoch():
        return _run_train_epoch(model, train_loader, criterion, optimizer,
                                device, select_inputs)

    def run_validation():
        val_loss, val_acc, all_labels, all_preds, _ = _run_eval_epoch(
            model, val_loader, criterion, device, select_inputs,
            collect_probs=False)
        return val_loss, val_acc, compute_metrics(all_labels, all_preds)

    return _fit(model, optimizer, run_train_epoch, run_validation,
                num_epochs, save_dir, model_name)