185 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			185 lines
		
	
	
		
			5.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import torch
 | 
						|
import numpy as np
 | 
						|
import matplotlib.pyplot as plt
 | 
						|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
 | 
						|
from sklearn.manifold import TSNE
 | 
						|
import seaborn as sns
 | 
						|
from training.utils import compute_metrics, plot_confusion_matrix, save_results
 | 
						|
 | 
						|
def evaluate_model(model, data_loader, device, class_names=None):
    """Evaluate a two-input classifier over ``data_loader`` and print metrics.

    The model is put into eval mode (and left there) and every batch is run
    under ``torch.no_grad()``. Each batch must be a mapping with the keys
    ``'diff_img'``, ``'wnb_img'`` and ``'label'``. The two image tensors are
    passed positionally as ``model(diff_imgs, wnb_imgs)``.

    Args:
        model: Callable module producing class scores; softmax and arg-max
            are taken along dim 1 (presumably ``(batch, num_classes)`` logits
            -- confirm against the model).
        data_loader: Iterable of batch mappings as described above.
        device: Device the image and label tensors are moved to.
        class_names: Class names forwarded to ``plot_confusion_matrix``.
            When ``None``, ``['正常', '异常']`` ("normal", "abnormal") is used.

    Returns:
        A tuple ``(metrics, labels, preds, probs)``:
        ``metrics`` is the result of ``compute_metrics`` (this function reads
        its ``'accuracy'``, ``'precision'``, ``'recall'`` and ``'f1'`` keys);
        ``labels`` and ``preds`` are Python lists extended with each batch's
        NumPy values; ``probs`` is a NumPy array built from the per-sample
        softmax rows.
    """
    model.eval()

    labels_seen = []
    preds_seen = []
    probs_seen = []

    with torch.no_grad():
        for batch in data_loader:
            diff_batch = batch['diff_img'].to(device)
            wnb_batch = batch['wnb_img'].to(device)
            target = batch['label'].to(device)

            logits = model(diff_batch, wnb_batch)

            # Class probabilities, and the index of the highest-scoring class.
            batch_probs = torch.softmax(logits, dim=1)
            batch_preds = torch.max(logits, 1)[1]

            # Move everything to host memory so results can be aggregated.
            labels_seen.extend(target.cpu().numpy())
            preds_seen.extend(batch_preds.cpu().numpy())
            probs_seen.extend(batch_probs.cpu().numpy())

    metrics = compute_metrics(labels_seen, preds_seen, probs_seen)

    names = class_names if class_names is not None else ['正常', '异常']
    plot_confusion_matrix(labels_seen, preds_seen, names)

    print(f"准确率: {metrics['accuracy']:.4f}")
    print(f"精确率: {metrics['precision']:.4f}")
    print(f"召回率: {metrics['recall']:.4f}")
    print(f"F1值: {metrics['f1']:.4f}")

    return metrics, labels_seen, preds_seen, np.array(probs_seen)


def plot_roc_curve(all_labels, all_probs, save_path=None):
    """Plot the ROC curve of a binary classifier and return its AUC.

    Args:
        all_labels: Ground-truth label per sample, passed unchanged to
            ``sklearn.metrics.roc_curve``.
        all_probs: Either a 2-D array-like of per-class probabilities (one row
            per sample, such as the ``probs`` returned by ``evaluate_model``),
            in which case column 1 -- the positive ("abnormal") class -- is
            used; or a 1-D array-like that already holds the positive-class
            score of each sample. Python lists are accepted as well as arrays.
        save_path: If truthy, the figure is saved to this path before
            ``plt.show()`` is called.

    Returns:
        The area under the ROC curve, as computed by ``sklearn.metrics.auc``.
    """
    probs = np.asarray(all_probs)
    # 2-D input: one probability column per class, and for this binary task
    # the positive (abnormal) class is column 1. 1-D input: already the
    # positive-class score, so use it as-is.
    pos_probs = probs if probs.ndim == 1 else probs[:, 1]

    # roc_curve also returns the decision thresholds, which are not needed.
    fpr, tpr, _ = roc_curve(all_labels, pos_probs)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
    # Diagonal reference line: the ROC curve of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率')
    plt.ylabel('真阳性率')
    plt.title('受试者工作特征(ROC)曲线')
    plt.legend(loc='lower right')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)

    plt.show()

    return roc_auc


def extract_and_visualize_features(model, data_loader, device, save_path=None):
    """Extract per-modality features and visualize them with t-SNE.

    The model is put into eval mode and ``model.extract_features(diff_imgs,
    wnb_imgs)`` is called on every batch under ``torch.no_grad()``. A 1x3
    figure is then drawn with the following panels:
    1) a t-SNE embedding of the ``'diff'`` features;
    2) a t-SNE embedding of the ``'wnb'`` features;
    3) a t-SNE embedding of both concatenated per sample.
    Points are coloured by label.

    Args:
        model: Object with ``eval()`` and ``extract_features(diff, wnb)``.
            The latter must return a mapping containing ``'diff'`` and
            ``'wnb'`` tensors whose first dimension indexes samples
            (presumably ``(batch, ...)`` -- confirm against the model).
            Extra per-sample dimensions are flattened before t-SNE.
        data_loader: Iterable of mappings with ``'diff_img'``, ``'wnb_img'``
            and ``'label'`` tensors.
        device: Device the image tensors are moved to.
        save_path: If truthy, the figure is saved to this path before
            ``plt.show()`` is called.

    Raises:
        ValueError: If fewer than two samples were collected, since t-SNE
            cannot embed them.
    """
    model.eval()

    features_dict = {'diff': [], 'wnb': []}
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].cpu().numpy()

            batch_features = model.extract_features(diff_imgs, wnb_imgs)

            # Only the two modalities used below are collected.
            for modality in features_dict:
                features_dict[modality].extend(batch_features[modality].cpu().numpy())
            all_labels.extend(labels)

    all_labels = np.array(all_labels)
    n_samples = len(all_labels)
    if n_samples < 2:
        raise ValueError(
            f"t-SNE visualization needs at least 2 samples, got {n_samples}")

    # Stack each modality into an (n_samples, dim) matrix once. Multi-dimensional
    # per-sample features (e.g. feature maps) are flattened, because t-SNE and
    # the concatenation below both need 2-D input.
    modality_features = {}
    for modality, rows in features_dict.items():
        arr = np.asarray(rows)
        modality_features[modality] = arr.reshape(arr.shape[0], -1)

    plt.figure(figsize=(15, 5))

    # Panels 1 and 2: one per modality.
    for i, (modality, features) in enumerate(modality_features.items()):
        embedding = _fit_tsne_2d(features)
        plt.subplot(1, 3, i + 1)
        _scatter_by_label(embedding, all_labels, f'{modality} 特征 t-SNE 可视化')

    # Panel 3: both modalities concatenated per sample.
    combined_features = np.concatenate(
        [modality_features['diff'], modality_features['wnb']], axis=1)
    combined_tsne = _fit_tsne_2d(combined_features)
    plt.subplot(1, 3, 3)
    _scatter_by_label(combined_tsne, all_labels, '融合特征 t-SNE 可视化')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    plt.show()


def _fit_tsne_2d(features):
    """Embed a 2-D ``(n_samples, dim)`` feature matrix into 2-D with t-SNE."""
    n_samples = features.shape[0]
    # scikit-learn requires perplexity < n_samples. Keep its default of 30
    # for normal-sized sets, and lower it only for small evaluation sets.
    perplexity = min(30.0, float(n_samples - 1))
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    return tsne.fit_transform(features)


def _scatter_by_label(embedding, labels, title):
    """Scatter a 2-D embedding on the current axes, one series per label."""
    for label in np.unique(labels):
        idx = labels == label
        plt.scatter(embedding[idx, 0], embedding[idx, 1],
                    label=f"类别 {label}", alpha=0.7)
    plt.title(title)
    plt.xlabel('t-SNE 特征 1')
    plt.ylabel('t-SNE 特征 2')
    plt.legend()
    plt.grid(True)


def compare_models(models_results, model_names, save_path=None):
    """Draw a grouped bar chart comparing several models' metrics.

    Args:
        models_results: Iterable of mappings, one per model. Each must provide
            the keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
            ``'f1'`` (for example the ``metrics`` returned by
            ``evaluate_model``).
        model_names: Iterable of legend labels, one per entry of
            ``models_results`` and in the same order.
        save_path: If truthy, the figure is saved to this path before
            ``plt.show()`` is called.

    Raises:
        ValueError: If ``models_results`` is empty, or if the number of
            results differs from the number of names.
    """
    metrics = ['accuracy', 'precision', 'recall', 'f1']

    results = list(models_results)
    names = list(model_names)
    n_models = len(results)
    # Without these checks an empty input would divide by zero when computing
    # the bar width, and zip() would silently drop unmatched models or names.
    if n_models == 0:
        raise ValueError("models_results must contain at least one result")
    if len(names) != n_models:
        raise ValueError(
            f"got {n_models} model results but {len(names)} model names")

    # One row per model, one column per metric.
    values = np.array([[result[metric] for metric in metrics] for result in results])

    plt.figure(figsize=(10, 6))
    x = np.arange(len(metrics))
    # The models' bars share 80% of each metric's unit-wide slot.
    width = 0.8 / n_models

    for i, (name, vals) in enumerate(zip(names, values)):
        plt.bar(x + i * width, vals, width, label=name)

    plt.xlabel('评估指标')
    plt.ylabel('分数')
    plt.title('不同模型性能比较')
    # Put each metric's tick under the centre of its group of bars.
    plt.xticks(x + width * (n_models - 1) / 2, metrics)
    plt.ylim(0, 1.0)
    plt.legend()
    plt.grid(True, axis='y')

    if save_path:
        plt.savefig(save_path)

    plt.show()