312 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			312 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import torch.nn as nn
 | 
						|
import torch.optim as optim
 | 
						|
import numpy as np
 | 
						|
import os
 | 
						|
from tqdm import tqdm
 | 
						|
import time
 | 
						|
from training.utils import AverageMeter, EarlyStopping, plot_training_curves, compute_metrics
 | 
						|
 | 
						|
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader``, updating the model.

    Every batch is indexed with the keys ``'diff_img'``, ``'wnb_img'`` and
    ``'label'``; the three entries are moved to ``device`` and the model is
    called as ``model(diff_imgs, wnb_imgs)``.

    Returns:
        A ``(loss, accuracy)`` tuple read from the ``avg`` attribute of the
        two ``AverageMeter`` trackers once the loader is exhausted.
    """
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress = tqdm(train_loader, desc='训练')
    for batch in progress:
        # Pull both image inputs and the targets onto the target device.
        diff_imgs, wnb_imgs, labels = (
            batch[key].to(device) for key in ('diff_img', 'wnb_img', 'label')
        )

        # Forward pass and loss.
        logits = model(diff_imgs, wnb_imgs)
        loss = criterion(logits, labels)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Predicted class = index of the largest value along dim 1.
        predicted = logits.max(dim=1)[1]
        batch_acc = (predicted == labels).float().mean()

        # Each meter update is passed the batch size as its second argument.
        n_samples = labels.size(0)
        loss_meter.update(loss.item(), n_samples)
        acc_meter.update(batch_acc.item(), n_samples)

        # A dict (not keyword arguments) keeps 'loss' ahead of 'acc' in the bar.
        progress.set_postfix({'loss': loss_meter.avg, 'acc': acc_meter.avg})

    return loss_meter.avg, acc_meter.avg
 | 
						|
 | 
						|
 | 
						|
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Batches are read with the keys ``'diff_img'``, ``'wnb_img'`` and
    ``'label'``. Labels, predicted classes and softmax probabilities from
    every batch are collected and handed to ``compute_metrics``.

    Returns:
        A ``(loss, accuracy, metrics)`` tuple. ``loss`` and ``accuracy`` are
        the ``avg`` values of the two ``AverageMeter`` trackers; ``metrics``
        is the object returned by ``compute_metrics`` with the extra entries
        ``'loss'`` and ``'accuracy'`` set to those same values.
    """
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    all_labels, all_preds, all_probs = [], [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc='验证'):
            # Move both image inputs and the targets to the target device.
            diff_imgs, wnb_imgs, labels = (
                batch[key].to(device)
                for key in ('diff_img', 'wnb_img', 'label')
            )

            # Forward pass and loss.
            logits = model(diff_imgs, wnb_imgs)
            loss = criterion(logits, labels)

            # Softmax over dim 1, and the index of the largest value there.
            probs = torch.softmax(logits, dim=1)
            predicted = logits.max(dim=1)[1]
            batch_acc = (predicted == labels).float().mean()

            # Each meter update is passed the batch size as its second argument.
            n_samples = labels.size(0)
            loss_meter.update(loss.item(), n_samples)
            acc_meter.update(batch_acc.item(), n_samples)

            # Keep per-sample results on the CPU for the metric computation.
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Additional evaluation metrics over the whole validation set.
    metrics = compute_metrics(all_labels, all_preds, all_probs)
    metrics['loss'] = loss_meter.avg
    metrics['accuracy'] = acc_meter.avg

    return loss_meter.avg, acc_meter.avg, metrics
 | 
						|
 | 
						|
 | 
						|
def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                num_epochs=50, save_dir='./models', model_name='model'):
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping.

    Each epoch calls ``train_epoch`` and then ``validate``. Two files are
    used inside ``save_dir``:

    * ``{model_name}_best_acc.pth`` -- written here whenever the validation
      accuracy exceeds the best value seen so far; holds the epoch index,
      model and optimizer state dicts, the accuracy and the metrics.
    * ``{model_name}_best.pth`` -- the path handed to ``EarlyStopping``.
      Presumably ``EarlyStopping`` saves the model state dict there when the
      validation loss improves (its code is not in this file -- confirm).
      This file is loaded back into ``model`` before returning.

    Args:
        model: Network called as ``model(diff_imgs, wnb_imgs)``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs to run.
        save_dir: Directory for checkpoints; created if missing.
        model_name: Prefix of the checkpoint file names.

    Returns:
        ``(model, history)`` where ``history`` is a dict with the keys
        ``'train_losses'``, ``'val_losses'``, ``'train_accs'``,
        ``'val_accs'``, ``'best_val_acc'`` and ``'total_time'`` (seconds).
    """
    # exist_ok avoids the check-then-create race when several runs share
    # the same save_dir.
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    # Early stopping is driven by the validation loss.
    early_stopping = EarlyStopping(patience=10, path=best_model_path)

    # Per-epoch training history.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        # Training pass.
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        # Validation pass.
        val_loss, val_acc, val_metrics = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Report the results of this epoch.
        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Checkpoint on a new best validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        # Early-stopping check on the validation loss.
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    # Total wall-clock training time.
    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the weights stored at the early-stopping checkpoint path.
    model.load_state_dict(torch.load(best_model_path))

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }
 | 
						|
 | 
						|
 | 
						|
def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device,
                             num_epochs=50, save_dir='./models', model_name='single_modal', modal_key='diff_img'):
    """Train a single-input model for up to ``num_epochs`` epochs.

    Each batch is read as ``batch[modal_key]`` (input) and ``batch['label']``
    (target), and the model is called as ``model(imgs)``. Every epoch runs a
    training pass followed by a validation pass under ``torch.no_grad()``.

    Two files are used inside ``save_dir``:

    * ``{model_name}_best_acc.pth`` -- written here whenever the validation
      accuracy exceeds the best value seen so far; holds the epoch index,
      model and optimizer state dicts, the accuracy and the metrics.
    * ``{model_name}_best.pth`` -- the path handed to ``EarlyStopping``.
      Presumably ``EarlyStopping`` saves the model state dict there when the
      validation loss improves (its code is not in this file -- confirm).
      This file is loaded back into ``model`` before returning.

    Args:
        model: Network called as ``model(imgs)``.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        criterion: Loss callable ``criterion(outputs, labels)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs to run.
        save_dir: Directory for checkpoints; created if missing.
        model_name: Prefix of the checkpoint file names.
        modal_key: Batch key selecting the input fed to the model.

    Returns:
        ``(model, history)`` where ``history`` is a dict with the keys
        ``'train_losses'``, ``'val_losses'``, ``'train_accs'``,
        ``'val_accs'``, ``'best_val_acc'`` and ``'total_time'`` (seconds).
    """
    # exist_ok avoids the check-then-create race when several runs share
    # the same save_dir.
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    # Early stopping is driven by the validation loss.
    early_stopping = EarlyStopping(patience=10, path=best_model_path)

    # Per-epoch training history.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        # ---- Training pass ----
        model.train()
        train_loss = AverageMeter()
        train_acc = AverageMeter()

        pbar = tqdm(train_loader, desc='训练')
        for batch in pbar:
            # Selected modality and targets, moved to the target device.
            imgs = batch[modal_key].to(device)
            labels = batch['label'].to(device)

            # Forward pass and loss.
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Predicted class = index of the largest value along dim 1.
            _, preds = torch.max(outputs, 1)
            batch_acc = (preds == labels).float().mean()

            # Each meter update is passed the batch size as its second argument.
            train_loss.update(loss.item(), labels.size(0))
            train_acc.update(batch_acc.item(), labels.size(0))

            pbar.set_postfix({
                'loss': train_loss.avg,
                'acc': train_acc.avg
            })

        train_losses.append(train_loss.avg)
        train_accs.append(train_acc.avg)

        # ---- Validation pass ----
        model.eval()
        val_loss = AverageMeter()
        val_acc = AverageMeter()

        all_labels = []
        all_preds = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc='验证'):
                imgs = batch[modal_key].to(device)
                labels = batch['label'].to(device)

                outputs = model(imgs)
                loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                batch_acc = (preds == labels).float().mean()

                val_loss.update(loss.item(), labels.size(0))
                val_acc.update(batch_acc.item(), labels.size(0))

                # Keep per-sample results on the CPU for the metric computation.
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        val_losses.append(val_loss.avg)
        val_accs.append(val_acc.avg)

        # Additional evaluation metrics (no probabilities are passed here).
        val_metrics = compute_metrics(all_labels, all_preds)

        # Report the results of this epoch.
        print(f'训练损失: {train_loss.avg:.4f} 训练准确率: {train_acc.avg:.4f}')
        print(f'验证损失: {val_loss.avg:.4f} 验证准确率: {val_acc.avg:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Checkpoint on a new best validation accuracy.
        if val_acc.avg > best_val_acc:
            best_val_acc = val_acc.avg
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc.avg,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc.avg:.4f}')

        # Early-stopping check on the validation loss.
        early_stopping(val_loss.avg, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    # Total wall-clock training time.
    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the weights stored at the early-stopping checkpoint path.
    model.load_state_dict(torch.load(best_model_path))

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }