import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix


class AverageMeter:
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: Latest value (e.g. a per-batch mean loss).
            n: Weight of ``val``, typically the number of samples it averages.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division so zero-weight updates on an empty meter don't raise.
        self.avg = self.sum / self.count if self.count else 0


class EarlyStopping:
    """Stop training early when the validation loss stops improving, to avoid overfitting.

    Each time the validation loss improves, the model's ``state_dict`` is saved
    to ``path``.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set to True.
        delta: Minimum increase of the score (negated loss) required to count
            as an improvement.
        path: File the best model's ``state_dict`` is written to.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None  # negated best validation loss (higher is better)
        self.early_stop = False
        self.val_loss_min = np.inf  # lowest validation loss checkpointed so far

    def __call__(self, val_loss, model):
        """Update the stopping state with a new validation loss.

        Args:
            val_loss: Validation loss for the current epoch (lower is better).
            model: Module whose ``state_dict`` is saved on improvement.
        """
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save ``model``'s weights to ``self.path`` and record ``val_loss`` as the new minimum."""
        directory = os.path.dirname(self.path)
        if directory:
            # torch.save does not create missing parent directories.
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), self.path)
        # best_score has already been overwritten with -val_loss by the caller,
        # so the previous best loss must be tracked separately in val_loss_min.
        print(f'验证损失降低 ({self.val_loss_min:.6f} --> {val_loss:.6f}). 保存模型...')
        self.val_loss_min = float(val_loss)


def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """Plot training/validation loss and accuracy curves side by side.

    Args:
        train_losses, val_losses: Per-epoch loss sequences.
        train_accs, val_accs: Per-epoch accuracy sequences.
        save_path: If given, the figure is also saved to this path.
    """
    # NOTE(review): the Chinese labels need a CJK-capable matplotlib font to render.
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='训练损失')
    plt.plot(val_losses, label='验证损失')
    plt.title('损失曲线')
    plt.xlabel('Epoch')
    plt.ylabel('损失')
    plt.legend()
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='训练准确率')
    plt.plot(val_accs, label='验证准确率')
    plt.title('准确率曲线')
    plt.xlabel('Epoch')
    plt.ylabel('准确率')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
    """Plot the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Optional sequence (list, tuple or array) of tick labels;
            seaborn's automatic labels are used when None or empty.
        save_path: If given, the figure is also saved to this path.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Explicit None/len check: plain truthiness raises ValueError for NumPy arrays.
    if class_names is None or len(class_names) == 0:
        tick_labels = "auto"
    else:
        tick_labels = list(class_names)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')
    if save_path:
        plt.savefig(save_path)
    plt.show()


def compute_metrics(y_true, y_pred, y_proba=None, average='binary'):
    """Compute accuracy, precision, recall and F1.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_proba: Accepted for API compatibility; currently unused.
        average: Averaging mode forwarded to sklearn's precision/recall/F1
            (``'binary'`` by default; e.g. ``'macro'`` for multiclass labels).

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
    """
    accuracy = accuracy_score(y_true, y_pred)
    # zero_division=0 everywhere so undefined ratios consistently yield 0
    # instead of emitting UndefinedMetricWarning for recall/F1 only.
    precision = precision_score(y_true, y_pred, average=average, zero_division=0)
    recall = recall_score(y_true, y_pred, average=average, zero_division=0)
    f1 = f1_score(y_true, y_pred, average=average, zero_division=0)

    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }
    return metrics


def save_results(results, filename):
    """Write each ``key: value`` pair of ``results`` on its own line to a UTF-8 text file."""
    with open(filename, 'w', encoding='utf-8') as f:
        for key, value in results.items():
            f.write(f"{key}: {value}\n")