leukemia/training/train.py

312 lines
10 KiB
Python

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
from tqdm import tqdm
import time
from training.utils import AverageMeter, EarlyStopping, plot_training_curves, compute_metrics
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Every batch is indexed with the keys ``'diff_img'``, ``'wnb_img'`` and
    ``'label'``; the two image tensors are fed to ``model`` together and the
    result is scored against the labels with ``criterion``.

    Args:
        model: Module called as ``model(diff_imgs, wnb_imgs)``.
        train_loader: Iterable of batches (mappings with the keys above).
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer stepped once per batch.
        device: Target passed to ``Tensor.to`` for every batch tensor.

    Returns:
        ``(loss_avg, acc_avg)`` -- the ``.avg`` attributes of the two
        ``AverageMeter`` instances after the last batch.
    """
    model.train()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()

    progress = tqdm(train_loader, desc='训练')
    for sample in progress:
        diff_batch = sample['diff_img'].to(device)
        wnb_batch = sample['wnb_img'].to(device)
        targets = sample['label'].to(device)

        # Forward pass and loss.
        logits = model(diff_batch, wnb_batch)
        batch_loss = criterion(logits, targets)

        # Backward pass and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Fraction of samples whose arg-max class (along dim 1) equals the label.
        predicted = torch.max(logits, 1)[1]
        hit_rate = (predicted == targets).float().mean()

        # Each meter receives the batch value plus the batch's first-dimension size.
        n_samples = targets.size(0)
        loss_meter.update(batch_loss.item(), n_samples)
        acc_meter.update(hit_rate.item(), n_samples)

        # Passed as a dict (not keyword arguments) so the key order is kept.
        running = {
            'loss': loss_meter.avg,
            'acc': acc_meter.avg,
        }
        progress.set_postfix(running)

    return loss_meter.avg, acc_meter.avg
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Batches are indexed with the keys ``'diff_img'``, ``'wnb_img'`` and
    ``'label'``. Labels, arg-max predictions and softmax probabilities of all
    batches are collected and handed to ``compute_metrics``.

    Args:
        model: Module called as ``model(diff_imgs, wnb_imgs)``.
        val_loader: Iterable of batches (mappings with the keys above).
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        device: Target passed to ``Tensor.to`` for every batch tensor.

    Returns:
        ``(loss_avg, acc_avg, metrics)`` where ``metrics`` is the object
        returned by ``compute_metrics`` with the extra entries ``'loss'`` and
        ``'accuracy'`` set to the two averages.
    """
    model.eval()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()
    seen_labels, seen_preds, seen_probs = [], [], []

    with torch.no_grad():
        for sample in tqdm(val_loader, desc='验证'):
            diff_batch = sample['diff_img'].to(device)
            wnb_batch = sample['wnb_img'].to(device)
            targets = sample['label'].to(device)

            logits = model(diff_batch, wnb_batch)
            batch_loss = criterion(logits, targets)

            # Class probabilities and hard predictions along dim 1.
            class_probs = torch.softmax(logits, dim=1)
            predicted = torch.max(logits, 1)[1]
            hit_rate = (predicted == targets).float().mean()

            n_samples = targets.size(0)
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(hit_rate.item(), n_samples)

            # Accumulate per-sample results on the CPU for the metric computation.
            seen_labels.extend(targets.cpu().numpy())
            seen_preds.extend(predicted.cpu().numpy())
            seen_probs.extend(class_probs.cpu().numpy())

    metrics = compute_metrics(seen_labels, seen_preds, seen_probs)
    metrics['loss'] = loss_meter.avg
    metrics['accuracy'] = acc_meter.avg

    return loss_meter.avg, acc_meter.avg, metrics
def train_model(model, train_loader, val_loader, criterion, optimizer, device,
                num_epochs=50, save_dir='./models', model_name='model'):
    """Train ``model`` for up to ``num_epochs`` epochs with early stopping.

    Each epoch runs ``train_epoch`` followed by ``validate``. Two files may be
    written under ``save_dir``:

    * ``{model_name}_best_acc.pth`` -- a dict with epoch index, model and
      optimizer state dicts, validation accuracy and validation metrics,
      saved whenever the validation accuracy exceeds the best seen so far.
    * ``{model_name}_best.pth`` -- the path handed to ``EarlyStopping``
      (presumably where it stores the weights with the best validation loss;
      confirm in ``training.utils``).

    After training, the weights stored at ``{model_name}_best.pth`` are loaded
    back into ``model``. Note that this checkpoint is selected by the early
    stopping helper from the validation loss, whereas ``best_val_acc`` in the
    returned history tracks the validation accuracy.

    Args:
        model: Module trained in place.
        train_loader: Batches for ``train_epoch``.
        val_loader: Batches for ``validate``.
        criterion: Loss callable.
        optimizer: Optimizer stepped by ``train_epoch``.
        device: Device the batches are moved to; also used as ``map_location``
            when reloading the checkpoint.
        num_epochs: Maximum number of epochs.
        save_dir: Directory for checkpoints; created if missing.
        model_name: Prefix of the checkpoint file names.

    Returns:
        ``(model, history)`` where ``history`` is a dict with the keys
        ``train_losses``, ``val_losses``, ``train_accs``, ``val_accs``,
        ``best_val_acc`` and ``total_time`` (seconds).
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    early_stopping = EarlyStopping(patience=10, path=best_model_path)

    # Per-epoch history.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, val_metrics = validate(model, val_loader, criterion, device)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Keep a full checkpoint of the epoch with the highest validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        # Early stopping is driven by the validation loss.
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the checkpoint tracked by early stopping. The file does not
    # exist when no epoch ran (e.g. num_epochs=0), so guard the load instead
    # of failing after training has already finished.
    if os.path.isfile(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print(f'未找到最佳模型文件: {best_model_path}, 保留当前模型权重')

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }
def _single_modal_train_pass(model, loader, criterion, optimizer, device, modal_key):
    """Train ``model`` for one pass over ``loader``; return ``(loss_avg, acc_avg)``."""
    model.train()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    pbar = tqdm(loader, desc='训练')
    for batch in pbar:
        imgs = batch[modal_key].to(device)
        labels = batch['label'].to(device)

        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fraction of samples whose arg-max class (along dim 1) equals the label.
        _, preds = torch.max(outputs, 1)
        batch_acc = (preds == labels).float().mean()

        loss_meter.update(loss.item(), labels.size(0))
        acc_meter.update(batch_acc.item(), labels.size(0))

        pbar.set_postfix({
            'loss': loss_meter.avg,
            'acc': acc_meter.avg
        })

    return loss_meter.avg, acc_meter.avg


def _single_modal_eval_pass(model, loader, criterion, device, modal_key):
    """Evaluate ``model`` on ``loader``; return ``(loss_avg, acc_avg, metrics)``."""
    model.eval()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='验证'):
            imgs = batch[modal_key].to(device)
            labels = batch['label'].to(device)

            outputs = model(imgs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            batch_acc = (preds == labels).float().mean()

            loss_meter.update(loss.item(), labels.size(0))
            acc_meter.update(batch_acc.item(), labels.size(0))

            # Collected on the CPU for the metric computation below.
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    metrics = compute_metrics(all_labels, all_preds)
    return loss_meter.avg, acc_meter.avg, metrics


def train_single_modal_model(model, train_loader, val_loader, criterion, optimizer, device,
                             num_epochs=50, save_dir='./models', model_name='single_modal', modal_key='diff_img'):
    """Train a model that consumes a single image modality.

    Each batch is indexed with ``modal_key`` for the input and ``'label'`` for
    the target, and the model is called as ``model(imgs)``. Two files may be
    written under ``save_dir``:

    * ``{model_name}_best_acc.pth`` -- a dict with epoch index, model and
      optimizer state dicts, validation accuracy and validation metrics,
      saved whenever the validation accuracy exceeds the best seen so far.
    * ``{model_name}_best.pth`` -- the path handed to ``EarlyStopping``
      (presumably where it stores the weights with the best validation loss;
      confirm in ``training.utils``).

    After training, the weights stored at ``{model_name}_best.pth`` are loaded
    back into ``model``.

    Args:
        model: Module trained in place.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer stepped once per training batch.
        device: Device the batches are moved to; also used as ``map_location``
            when reloading the checkpoint.
        num_epochs: Maximum number of epochs.
        save_dir: Directory for checkpoints; created if missing.
        model_name: Prefix of the checkpoint file names.
        modal_key: Batch key selecting the input modality.

    Returns:
        ``(model, history)`` where ``history`` is a dict with the keys
        ``train_losses``, ``val_losses``, ``train_accs``, ``val_accs``,
        ``best_val_acc`` and ``total_time`` (seconds).
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_dir, exist_ok=True)
    best_model_path = os.path.join(save_dir, f'{model_name}_best.pth')
    best_acc_path = os.path.join(save_dir, f'{model_name}_best_acc.pth')

    early_stopping = EarlyStopping(patience=10, path=best_model_path)

    # Per-epoch history.
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    best_val_acc = 0.0

    start_time = time.time()
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 20)

        train_loss, train_acc = _single_modal_train_pass(
            model, train_loader, criterion, optimizer, device, modal_key)
        train_losses.append(train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc, val_metrics = _single_modal_eval_pass(
            model, val_loader, criterion, device, modal_key)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        print(f'训练损失: {train_loss:.4f} 训练准确率: {train_acc:.4f}')
        print(f'验证损失: {val_loss:.4f} 验证准确率: {val_acc:.4f}')
        print(f'验证指标: 精确率={val_metrics["precision"]:.4f}, 召回率={val_metrics["recall"]:.4f}, F1={val_metrics["f1"]:.4f}')

        # Keep a full checkpoint of the epoch with the highest validation accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_metrics': val_metrics
            }, best_acc_path)
            print(f'保存新的最佳模型,验证准确率: {val_acc:.4f}')

        # Early stopping is driven by the validation loss.
        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(f"早停! 在第 {epoch+1} 个Epoch停止训练")
            break

    total_time = time.time() - start_time
    print(f'训练完成! 总用时: {total_time/60:.2f} 分钟')

    # Restore the checkpoint tracked by early stopping. The file does not
    # exist when no epoch ran (e.g. num_epochs=0), so guard the load instead
    # of failing after training has already finished.
    if os.path.isfile(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print(f'未找到最佳模型文件: {best_model_path}, 保留当前模型权重')

    return model, {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'best_val_acc': best_val_acc,
        'total_time': total_time
    }