126 lines
3.6 KiB
Python
126 lines
3.6 KiB
Python
|
|
import torch
|
||
|
|
import numpy as np
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
|
||
|
|
import seaborn as sns
|
||
|
|
import os
|
||
|
|
|
||
|
|
class AverageMeter:
    """Keep the latest value alongside a weighted running average.

    Attributes:
        val: the value given to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
        avg: ``sum / count`` as of the most recent ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every tracked statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Raises:
            ZeroDivisionError: if the accumulated weight is 0 after adding ``n``.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
||
|
|
|
||
|
|
|
||
|
|
class EarlyStopping:
    """Stop training once the validation loss stops improving, to avoid overfitting.

    Call the instance once per epoch with the current validation loss and the
    model. Whenever the loss improves, the model's ``state_dict`` is saved to
    ``path``. After ``patience`` consecutive calls without improvement,
    ``early_stop`` is set to True. The caller is expected to check it and
    break out of the training loop.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        delta: minimum decrease in loss that counts as an improvement.
        path: file the best model's ``state_dict`` is written to.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        # Best score seen so far, stored as the negated loss (higher is better).
        self.best_score = None
        self.early_stop = False
        # Lowest validation loss seen so far; used in the checkpoint log message.
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        """Update the stopping state with this epoch's validation loss."""
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we have is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better by at least ``delta``: count towards patience.
            # Note the strict ``<``: with delta=0 an equal loss counts as an
            # improvement and resets the counter.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save ``model.state_dict()`` to ``self.path`` and log old -> new best loss."""
        torch.save(model.state_dict(), self.path)
        # Log the previous best loss and the new one. The old code printed
        # ``best_score`` after it had already been overwritten, so both sides
        # showed the same negated value.
        print(f'验证损失降低 ({self.val_loss_min:.6f} --> {val_loss:.6f}). 保存模型...')
        self.val_loss_min = val_loss
|
||
|
|
|
||
|
|
|
||
|
|
def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """Draw training vs. validation loss (left) and accuracy (right) curves.

    Args:
        train_losses: per-epoch training loss values.
        val_losses: per-epoch validation loss values.
        train_accs: per-epoch training accuracy values.
        val_accs: per-epoch validation accuracy values.
        save_path: if truthy, the figure is also written here with
            ``plt.savefig`` before ``plt.show`` is called.
    """
    # NOTE(review): labels are Chinese; matplotlib's default font may render
    # them as empty boxes unless a CJK-capable font is configured.
    # Each row: (subplot index, train series, val series,
    #            train legend, val legend, title, y-axis label)
    panels = (
        (1, train_losses, val_losses, '训练损失', '验证损失', '损失曲线', '损失'),
        (2, train_accs, val_accs, '训练准确率', '验证准确率', '准确率曲线', '准确率'),
    )

    plt.figure(figsize=(12, 5))
    for position, train_series, val_series, train_label, val_label, title, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
    plt.show()
|
||
|
|
|
||
|
|
|
||
|
|
def _resolve_tick_labels(class_names):
    """Return heatmap tick labels: ``"auto"`` when none are given, else ``class_names``.

    ``len`` is used instead of truthiness because ``bool(np.array([...]))``
    raises ``ValueError`` for arrays with more than one element (and pandas
    Index objects behave the same way).
    """
    if class_names is None or len(class_names) == 0:
        return "auto"
    return class_names


def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
    """Plot the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        class_names: optional tick labels for both axes (list, tuple, numpy
            array or pandas Index). ``None`` or empty lets seaborn choose.
            NOTE(review): they are matched positionally to the rows/columns of
            ``confusion_matrix``, i.e. to the sorted labels present in
            ``y_true``/``y_pred``; confirm callers pass them in that order.
        save_path: if truthy, the figure is also written here with
            ``plt.savefig`` before ``plt.show`` is called.
    """
    cm = confusion_matrix(y_true, y_pred)
    # Same labels on both axes, so resolve them once.
    tick_labels = _resolve_tick_labels(class_names)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')

    if save_path:
        plt.savefig(save_path)

    plt.show()
|
||
|
|
|
||
|
|
|
||
|
|
def compute_metrics(y_true, y_pred, y_proba=None):
    """Compute accuracy, precision, recall and F1 for binary predictions.

    Args:
        y_true: ground-truth labels.
        y_pred: predicted labels.
        y_proba: currently unused; accepted so existing call sites keep working.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'``.

    Precision, recall and F1 use ``average='binary'``, so the labels must be
    binary (sklearn's default positive label is 1). All three pass
    ``zero_division=0``. Previously only precision did. With that setting,
    an undefined metric (e.g. no predicted or no actual positives) is
    reported as 0 without emitting an ``UndefinedMetricWarning``.
    """
    # Shared options so the three binary metrics cannot drift apart again.
    binary_opts = {'average': 'binary', 'zero_division': 0}

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, **binary_opts),
        'recall': recall_score(y_true, y_pred, **binary_opts),
        'f1': f1_score(y_true, y_pred, **binary_opts),
    }
|
||
|
|
|
||
|
|
|
||
|
|
def save_results(results, filename):
    """Write each ``key: value`` pair of ``results`` on its own line to ``filename``.

    The file is overwritten. It is written as UTF-8 so that non-ASCII keys or
    values (e.g. Chinese metric names) do not depend on the platform's default
    encoding, which could otherwise raise ``UnicodeEncodeError``.

    Args:
        results: mapping whose items are written in iteration order.
        filename: destination path (str or path-like).
    """
    lines = (f"{key}: {value}\n" for key, value in results.items())
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(lines)
|