# leukemia/training/utils.py
#
# 126 lines
# 3.6 KiB
# Python

import math
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
class AverageMeter:
    """Keep the latest value plus a count-weighted running average.

    Attributes:
        val: the value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over every update since the last reset.
        count: accumulated weight ``n`` since the last reset.
        avg: ``sum / count``, refreshed on every update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Call the instance once per epoch with the current validation loss. When
    the loss improves (by at least ``delta``), the model's ``state_dict()`` is
    saved to ``path`` and the patience counter is reset. Otherwise the counter
    is incremented, and ``early_stop`` becomes True once it reaches
    ``patience``.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        delta: minimum decrease in validation loss that counts as an
            improvement.
        path: file the best model's ``state_dict()`` is written to with
            ``torch.save``.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        # Negated validation loss of the best call so far (higher is better);
        # None until the first non-NaN loss has been seen.
        self.best_score = None
        # Validation loss of the most recently saved checkpoint; used as the
        # "before" value in the save log line.
        self.val_loss_min = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Update the stopping state with this epoch's validation loss.

        Args:
            val_loss: scalar validation loss (lower is better).
            model: object exposing ``state_dict()``; saved when the loss
                improves.
        """
        score = -val_loss
        # A NaN loss must never count as an improvement. A plain
        # ``score < best + delta`` test is False for NaN, so a diverged loss
        # would become the new "best". Every later comparison against that
        # NaN best would then also pass, and early stopping could never fire.
        if math.isnan(score):
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and log the loss change."""
        torch.save(model.state_dict(), self.path)
        # ``best_score`` has already been overwritten with ``-val_loss`` when
        # this runs, so the previous loss is tracked separately in
        # ``val_loss_min`` to print a meaningful "old --> new" pair.
        print(f'验证损失降低 ({self.val_loss_min:.6f} --> {val_loss:.6f}). 保存模型...')
        self.val_loss_min = val_loss
def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """Plot training/validation loss and accuracy curves side by side.

    Args:
        train_losses: per-epoch training loss values.
        val_losses: per-epoch validation loss values.
        train_accs: per-epoch training accuracy values.
        val_accs: per-epoch validation accuracy values.
        save_path: if truthy, the figure is also written there via ``savefig``.

    The figure is always closed after ``plt.show()``, even if saving fails.
    Under a non-interactive backend ``show`` does not release the figure, so
    without the close, repeated calls would accumulate open figures.
    """
    # NOTE(review): the labels below are Chinese; matplotlib needs a
    # CJK-capable font configured (e.g. rcParams['font.sans-serif']) or they
    # are likely to render as empty boxes.
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        ax_loss.plot(train_losses, label='训练损失')
        ax_loss.plot(val_losses, label='验证损失')
        ax_loss.set_title('损失曲线')
        ax_loss.set_xlabel('Epoch')
        ax_loss.set_ylabel('损失')
        ax_loss.legend()
        ax_loss.grid(True)

        ax_acc.plot(train_accs, label='训练准确率')
        ax_acc.plot(val_accs, label='验证准确率')
        ax_acc.set_title('准确率曲线')
        ax_acc.set_xlabel('Epoch')
        ax_acc.set_ylabel('准确率')
        ax_acc.legend()
        ax_acc.grid(True)

        fig.tight_layout()
        if save_path:
            fig.savefig(save_path)
        plt.show()
    finally:
        plt.close(fig)
def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
    """Plot the confusion matrix of ``y_pred`` against ``y_true`` as a heatmap.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        class_names: optional tick labels for both axes (list, tuple, NumPy
            array, ...). sklearn orders the matrix by the sorted union of the
            labels in ``y_true`` and ``y_pred``, so names should follow that
            order. None or empty falls back to seaborn's "auto" ticks.
        save_path: if truthy, the figure is also written there via ``savefig``.

    The figure is always closed after ``plt.show()`` so repeated calls do not
    accumulate open figures.
    """
    cm = confusion_matrix(y_true, y_pred)
    # ``if class_names`` raises ValueError for a NumPy array ("truth value of
    # an array is ambiguous"). Test for None/empty explicitly instead, keeping
    # the "auto" fallback for both.
    if class_names is None or len(class_names) == 0:
        tick_labels = "auto"
    else:
        tick_labels = class_names

    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=tick_labels, yticklabels=tick_labels, ax=ax)
        ax.set_title('混淆矩阵')
        ax.set_xlabel('预测标签')
        ax.set_ylabel('真实标签')
        if save_path:
            fig.savefig(save_path)
        plt.show()
    finally:
        plt.close(fig)
def compute_metrics(y_true, y_pred, y_proba=None, average='binary'):
    """Compute accuracy, precision, recall and F1 for predicted labels.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        y_proba: currently unused; accepted so callers that pass predicted
            probabilities keep working.
        average: averaging strategy forwarded to sklearn's precision, recall
            and F1 scorers. The default ``'binary'`` matches the previous
            fixed behaviour; use e.g. ``'macro'`` or ``'weighted'`` for
            multi-class labels.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
    """
    # zero_division=0 is applied to all three scorers. Precision already used
    # it; recall and F1 previously used sklearn's default "warn", which
    # returns the same 0.0 but also emits UndefinedMetricWarning.
    scorer_kwargs = {'average': average, 'zero_division': 0}
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, **scorer_kwargs),
        'recall': recall_score(y_true, y_pred, **scorer_kwargs),
        'f1': f1_score(y_true, y_pred, **scorer_kwargs),
    }
def save_results(results, filename):
    """Write each ``key: value`` pair of ``results`` to ``filename``, one per line.

    The file is overwritten and written as UTF-8 explicitly. The platform
    default encoding (e.g. cp1252 on Windows) cannot encode the Chinese text
    used elsewhere in this project and would raise UnicodeEncodeError.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f"{key}: {value}\n" for key, value in results.items())