import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
from sklearn.manifold import TSNE
import seaborn as sns

from training.utils import compute_metrics, plot_confusion_matrix, save_results


def evaluate_model(model, data_loader, device, class_names=None):
    """Evaluate model performance on a data loader."""
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            # Get data and labels
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(diff_imgs, wnb_imgs)

            # Get predictions and class probabilities
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            # Collect results
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Compute evaluation metrics
    metrics = compute_metrics(all_labels, all_preds, all_probs)

    # Plot the confusion matrix
    if class_names is None:
        class_names = ['Normal', 'Abnormal']
    plot_confusion_matrix(all_labels, all_preds, class_names)

    # Print evaluation results
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 score: {metrics['f1']:.4f}")

    return metrics, all_labels, all_preds, np.array(all_probs)


def plot_roc_curve(all_labels, all_probs, save_path=None):
    """Plot the ROC curve."""
    # For binary classification, take the probability of the positive (abnormal) class
    pos_probs = all_probs[:, 1]

    # Compute the ROC curve
    fpr, tpr, thresholds = roc_curve(all_labels, pos_probs)
    roc_auc = auc(fpr, tpr)

    # Plot the ROC curve
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc='lower right')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
    plt.show()

    return roc_auc


def extract_and_visualize_features(model, data_loader, device, save_path=None):
    """Extract features and visualize them with t-SNE."""
    model.eval()
    features_dict = {
        'diff': [],
        'wnb': []
    }
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            # Get data and labels
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].cpu().numpy()

            # Extract per-modality features
            batch_features = model.extract_features(diff_imgs, wnb_imgs)

            # Collect features and labels
            for modality in features_dict:
                features_dict[modality].extend(batch_features[modality].cpu().numpy())
            all_labels.extend(labels)

    # Convert to a NumPy array
    all_labels = np.array(all_labels)

    # Visualize the features of each modality
    plt.figure(figsize=(15, 5))
    for i, (modality, features) in enumerate(features_dict.items()):
        features = np.array(features)

        # Reduce dimensionality with t-SNE
        tsne = TSNE(n_components=2, random_state=42)
        features_tsne = tsne.fit_transform(features)

        # Plot the t-SNE embedding
        plt.subplot(1, 3, i + 1)
        for label in np.unique(all_labels):
            idx = all_labels == label
            plt.scatter(features_tsne[idx, 0], features_tsne[idx, 1], label=f"Class {label}", alpha=0.7)
        plt.title(f't-SNE visualization of {modality} features')
        plt.xlabel('t-SNE dimension 1')
        plt.ylabel('t-SNE dimension 2')
        plt.legend()
        plt.grid(True)

    # Concatenate the two modalities' features and visualize the fused representation
    combined_features = np.concatenate([features_dict['diff'], features_dict['wnb']], axis=1)
    tsne = TSNE(n_components=2, random_state=42)
    combined_tsne = tsne.fit_transform(combined_features)

    plt.subplot(1, 3, 3)
    for label in np.unique(all_labels):
        idx = all_labels == label
        plt.scatter(combined_tsne[idx, 0], combined_tsne[idx, 1], label=f"Class {label}", alpha=0.7)
    plt.title('t-SNE visualization of fused features')
    plt.xlabel('t-SNE dimension 1')
    plt.ylabel('t-SNE dimension 2')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


def compare_models(models_results, model_names, save_path=None):
    """Compare the performance of different models."""
    metrics = ['accuracy', 'precision', 'recall', 'f1']
    values = []
    for result in models_results:
        values.append([result[metric] for metric in metrics])
    values = np.array(values)

    # Grouped bar chart comparing the models
    plt.figure(figsize=(10, 6))
    x = np.arange(len(metrics))
    width = 0.8 / len(models_results)

    for i, (name, vals) in enumerate(zip(model_names, values)):
        plt.bar(x + i * width, vals, width, label=name)

    plt.xlabel('Metric')
    plt.ylabel('Score')
    plt.title('Performance comparison of different models')
    plt.xticks(x + width * (len(models_results) - 1) / 2, metrics)
    plt.ylim(0, 1.0)
    plt.legend()
    plt.grid(True, axis='y')

    if save_path:
        plt.savefig(save_path)
    plt.show()
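

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how the functions above are intended to be chained.
# The model and DataLoader construction below are placeholders/assumptions:
# any dual-branch model whose forward takes (diff_imgs, wnb_imgs) and any
# loader yielding batches with 'diff_img', 'wnb_img', and 'label' keys will do.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical: build/load your trained dual-branch model and test loader here.
    model = ...          # e.g. a dual-branch network restored from a checkpoint, moved to `device`
    test_loader = ...    # e.g. DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Evaluate, plot the ROC curve, and visualize the learned features.
    metrics, labels, preds, probs = evaluate_model(model, test_loader, device)
    plot_roc_curve(labels, probs, save_path='roc_curve.png')
    extract_and_visualize_features(model, test_loader, device, save_path='tsne_features.png')

    # compare_models takes a list of metric dicts plus matching names, e.g.:
    # compare_models([metrics, baseline_metrics], ['Dual-branch', 'Baseline'])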