126 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			126 lines
		
	
	
		
			3.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import numpy as np
 | 
						|
import matplotlib.pyplot as plt
 | 
						|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
 | 
						|
import seaborn as sns
 | 
						|
import os
 | 
						|
 | 
						|
class AverageMeter:
    """Track the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all accumulated statistics to zero."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # sum / count (0 until something has been counted)
        self.sum = 0    # weighted sum of all recorded values
        self.count = 0  # total weight, i.e. the sum of all n

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new value; presumably a per-batch mean metric when ``n``
                is the batch size (assumption — confirm against callers).
            n: Weight given to ``val``.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Only recompute when something has been counted: update(x, n=0) on an
        # empty meter would otherwise divide by zero.
        if self.count:
            self.avg = self.sum / self.count
 | 
						|
 | 
						|
 | 
						|
class EarlyStopping:
    """Stop training early once the validation loss stops improving.

    Each call compares a new validation loss with the best one seen so far.

    - Improvement: the loss drops by at least ``delta``. The model's
      ``state_dict`` is saved to ``path`` and the patience counter is reset.
    - No improvement: the counter is incremented. ``early_stop`` becomes True
      once it reaches ``patience``.

    A NaN loss is never treated as an improvement.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        delta: Minimum decrease in validation loss that counts as improvement.
        path: File the best weights are written to with ``torch.save``.
    """

    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None      # negated best validation loss (higher is better)
        self.val_loss_min = np.inf  # best validation loss so far, used for logging
        self.early_stop = False

    def __call__(self, val_loss, model):
        """Update the early-stopping state with a new validation loss.

        Args:
            val_loss: Scalar validation loss (Python/NumPy number or 0-d tensor).
            model: Module whose ``state_dict`` is saved on improvement.
        """
        # float() detaches 0-d tensors so no autograd graph or device memory is held.
        val_loss = float(val_loss)
        score = -val_loss

        # Phrased as ">=" plus an explicit NaN check: with the old "<" test a NaN
        # loss fell through to the improvement branch and overwrote the best
        # checkpoint.
        improved = not np.isnan(score) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save ``model``'s weights to ``self.path`` and log the loss improvement."""
        torch.save(model.state_dict(), self.path)
        # Log previous best -> new loss. The old message printed best_score after
        # it had already been overwritten, so both numbers were the same value.
        print(f'验证损失降低 ({self.val_loss_min:.6f} --> {val_loss:.6f}). 保存模型...')
        self.val_loss_min = val_loss
 | 
						|
 | 
						|
 | 
						|
def plot_training_curves(train_losses, val_losses, train_accs, val_accs, save_path=None):
    """Plot training vs. validation loss (left) and accuracy (right) per epoch.

    Args:
        train_losses, val_losses: Per-epoch loss sequences.
        train_accs, val_accs: Per-epoch accuracy sequences.
        save_path: If truthy, the figure is written there before being shown.

    NOTE(review): the labels are Chinese; matplotlib's default font may render
    them as boxes unless a CJK-capable font is configured.
    """
    # One entry per subplot, in left-to-right order:
    # (train series, val series, train label, val label, title, y-axis label)
    panels = (
        (train_losses, val_losses, '训练损失', '验证损失', '损失曲线', '损失'),
        (train_accs, val_accs, '训练准确率', '验证准确率', '准确率曲线', '准确率'),
    )

    plt.figure(figsize=(12, 5))
    for position, (train_series, val_series, train_label, val_label, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(train_series, label=train_label)
        plt.plot(val_series, label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
    plt.show()
 | 
						|
 | 
						|
 | 
						|
def plot_confusion_matrix(y_true, y_pred, class_names=None, save_path=None):
    """Draw the confusion matrix of ``y_true`` vs ``y_pred`` as an annotated heatmap.

    Rows are true labels and columns are predicted labels, as laid out by
    ``sklearn.metrics.confusion_matrix``.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Optional sequence (list, tuple, NumPy array, pandas Index)
            of tick labels. ``None`` or an empty sequence falls back to
            seaborn's automatic labels.
        save_path: If truthy, the figure is written there before being shown.
    """
    cm = confusion_matrix(y_true, y_pred)

    # Explicit None/len checks: the old "if class_names" raised ValueError
    # ("truth value of an array is ambiguous") for NumPy arrays / pandas Index.
    if class_names is None or len(class_names) == 0:
        tick_labels = "auto"
    else:
        tick_labels = list(class_names)

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title('混淆矩阵')
    plt.xlabel('预测标签')
    plt.ylabel('真实标签')

    if save_path:
        plt.savefig(save_path)

    plt.show()
 | 
						|
 | 
						|
 | 
						|
def compute_metrics(y_true, y_pred, y_proba=None, average='binary'):
    """Compute accuracy, precision, recall and F1 for a set of predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        y_proba: Accepted for interface compatibility; currently unused.
        average: Averaging strategy forwarded to scikit-learn's precision,
            recall and F1 scorers. The default ``'binary'`` matches the
            previous behavior; use e.g. ``'macro'`` or ``'weighted'`` for
            multiclass labels.

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1'``.
    """
    # Apply zero_division=0 to all three scorers. Previously only precision had
    # it, so recall/F1 emitted UndefinedMetricWarning for the same degenerate
    # inputs (while still returning 0).
    scorer_kwargs = {'average': average, 'zero_division': 0}

    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, **scorer_kwargs),
        'recall': recall_score(y_true, y_pred, **scorer_kwargs),
        'f1': f1_score(y_true, y_pred, **scorer_kwargs),
    }

    return metrics
 | 
						|
 | 
						|
 | 
						|
def save_results(results, filename):
    """Write each ``key: value`` pair of ``results`` to ``filename``, one per line.

    Any existing file is overwritten. The file is encoded as UTF-8 explicitly,
    so non-ASCII keys/values (e.g. Chinese metric names) are written
    correctly regardless of the platform's default encoding.

    Args:
        results: Mapping whose items are written in iteration order.
        filename: Destination path.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        f.writelines(f"{key}: {value}\n" for key, value in results.items())