185 lines
5.7 KiB
Python
185 lines
5.7 KiB
Python
|
|
import torch
|
||
|
|
import numpy as np
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc
|
||
|
|
from sklearn.manifold import TSNE
|
||
|
|
import seaborn as sns
|
||
|
|
from training.utils import compute_metrics, plot_confusion_matrix, save_results
|
||
|
|
|
||
|
|
def evaluate_model(model, data_loader, device, class_names=None):
    """Evaluate a two-input classifier over every batch of ``data_loader``.

    Runs ``model(diff_img, wnb_img)`` under ``torch.no_grad()`` with the
    model switched to eval mode. It collects ground-truth labels, arg-max
    predictions and softmax probabilities. Metric computation and
    confusion-matrix plotting are delegated to ``training.utils``, and the
    headline metrics are printed.

    Args:
        model: Module called as ``model(diff_imgs, wnb_imgs)``; its output is
            treated as per-class scores along dim 1 (softmax is applied
            there). The model is left in eval mode when this returns.
        data_loader: Iterable of dict batches providing ``'diff_img'``,
            ``'wnb_img'`` and ``'label'`` tensors.
        device: Device the image tensors are moved to for the forward pass.
        class_names: Display names for the confusion matrix. Defaults to
            ``['正常', '异常']`` (normal / abnormal).

    Returns:
        Tuple ``(metrics, all_labels, all_preds, all_probs)``:
        ``metrics`` is the dict returned by ``compute_metrics`` (it must at
        least contain ``accuracy``, ``precision``, ``recall`` and ``f1``).
        ``all_labels`` and ``all_preds`` are Python lists with one value per
        sample. ``all_probs`` is a NumPy array with one row of class
        probabilities per sample.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            # Only the image inputs are needed on the compute device.
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            # Labels are used purely for host-side bookkeeping, so keep them
            # on the CPU instead of paying a device round-trip per batch.
            labels = batch['label'].cpu()

            outputs = model(diff_imgs, wnb_imgs)

            # Probabilities for ROC/metrics; prediction is the arg-max class.
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_labels.extend(labels.numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    if not all_labels:
        raise ValueError("data_loader yielded no samples; nothing to evaluate")

    metrics = compute_metrics(all_labels, all_preds, all_probs)

    if class_names is None:
        class_names = ['正常', '异常']
    plot_confusion_matrix(all_labels, all_preds, class_names)

    print(f"准确率: {metrics['accuracy']:.4f}")
    print(f"精确率: {metrics['precision']:.4f}")
    print(f"召回率: {metrics['recall']:.4f}")
    print(f"F1值: {metrics['f1']:.4f}")

    return metrics, all_labels, all_preds, np.array(all_probs)
|
||
|
|
|
||
|
|
|
||
|
|
def plot_roc_curve(all_labels, all_probs, save_path=None):
    """Plot the ROC curve of a binary classifier and return its AUC.

    Args:
        all_labels: Ground-truth binary labels, one per sample.
        all_probs: Either a 2-D array-like of per-class probabilities, in
            which case the positive (abnormal) class is read from column 1,
            matching the array returned by ``evaluate_model``. Or a 1-D
            array-like that already holds the positive-class score for each
            sample. Plain lists are accepted as well as NumPy arrays.
        save_path: If given, the figure is saved there before being shown.

    Returns:
        The area under the ROC curve, as computed by ``sklearn.metrics.auc``.
    """
    # Normalise to an ndarray so list input supports column slicing.
    probs = np.asarray(all_probs)
    if probs.ndim == 1:
        # Already a vector of positive-class scores.
        pos_probs = probs
    else:
        # Per-class probabilities: take the positive (abnormal) class column.
        pos_probs = probs[:, 1]

    fpr, tpr, thresholds = roc_curve(all_labels, pos_probs)
    roc_auc = auc(fpr, tpr)

    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
    # Diagonal reference line = performance of a random classifier.
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('假阳性率')
    plt.ylabel('真阳性率')
    plt.title('受试者工作特征(ROC)曲线')
    plt.legend(loc='lower right')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)

    plt.show()

    return roc_auc
|
||
|
|
|
||
|
|
|
||
|
|
def _tsne_embed(features):
    """Project a 2-D ``(n_samples, n_features)`` array to 2-D with t-SNE."""
    n_samples = features.shape[0]
    if n_samples < 2:
        raise ValueError("t-SNE needs at least 2 samples to visualize features")
    # scikit-learn requires perplexity < n_samples. Keep the default (30)
    # when there is enough data and shrink it only for small sets.
    perplexity = float(min(30, n_samples - 1))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    return tsne.fit_transform(features)


def _plot_tsne_scatter(embedding, labels, title):
    """Scatter a 2-D embedding on the current axes, one series per label."""
    for label in np.unique(labels):
        idx = labels == label
        plt.scatter(embedding[idx, 0], embedding[idx, 1],
                    label=f"类别 {label}", alpha=0.7)
    plt.title(title)
    plt.xlabel('t-SNE 特征 1')
    plt.ylabel('t-SNE 特征 2')
    plt.legend()
    plt.grid(True)


def extract_and_visualize_features(model, data_loader, device, save_path=None):
    """Extract per-modality features and visualise them with t-SNE.

    Calls ``model.extract_features(diff_imgs, wnb_imgs)`` on every batch
    (under ``torch.no_grad()``, model in eval mode). That call is expected
    to return a mapping with ``'diff'`` and ``'wnb'`` feature tensors. The
    function then draws one t-SNE panel per modality plus a final panel for
    the two modalities concatenated, coloured by the batch ``'label'``.

    Args:
        model: Module exposing ``extract_features(diff_imgs, wnb_imgs)``.
            The model is left in eval mode when this returns.
        data_loader: Iterable of dict batches providing ``'diff_img'``,
            ``'wnb_img'`` and ``'label'`` tensors.
        device: Device the image tensors are moved to.
        save_path: If given, the figure is saved there before being shown.

    Raises:
        ValueError: If ``data_loader`` yields no batches or fewer than two
            samples in total (t-SNE cannot embed fewer).
    """
    model.eval()

    modalities = ('diff', 'wnb')
    feature_batches = {modality: [] for modality in modalities}
    label_batches = []

    with torch.no_grad():
        for batch in data_loader:
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            label_batches.append(batch['label'].cpu().numpy())

            batch_features = model.extract_features(diff_imgs, wnb_imgs)
            for modality in modalities:
                feature_batches[modality].append(batch_features[modality].cpu().numpy())

    if not label_batches:
        raise ValueError("data_loader yielded no batches; nothing to visualize")

    all_labels = np.concatenate(label_batches)
    # t-SNE needs 2-D input: flatten any trailing feature dimensions per
    # sample (a no-op for features that are already (n_samples, dim)).
    flat_features = {}
    for modality in modalities:
        stacked = np.concatenate(feature_batches[modality])
        flat_features[modality] = stacked.reshape(stacked.shape[0], -1)

    # One panel per modality plus one for the fused features.
    n_panels = len(modalities) + 1
    plt.figure(figsize=(15, 5))

    for i, modality in enumerate(modalities):
        embedding = _tsne_embed(flat_features[modality])
        plt.subplot(1, n_panels, i + 1)
        _plot_tsne_scatter(embedding, all_labels, f'{modality} 特征 t-SNE 可视化')

    # Fused view: concatenate both modalities' features sample-by-sample.
    combined_features = np.concatenate([flat_features['diff'], flat_features['wnb']], axis=1)
    combined_tsne = _tsne_embed(combined_features)
    plt.subplot(1, n_panels, n_panels)
    _plot_tsne_scatter(combined_tsne, all_labels, '融合特征 t-SNE 可视化')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)

    plt.show()
|
||
|
|
|
||
|
|
|
||
|
|
def compare_models(models_results, model_names, save_path=None, *, metric_names=None):
    """Draw a grouped bar chart comparing evaluation metrics across models.

    Args:
        models_results: Sequence of metric dicts, one per model (e.g. the
            ``metrics`` dict returned by ``evaluate_model``). Each dict must
            contain every key listed in ``metric_names``.
        model_names: Legend labels, aligned one-to-one with ``models_results``.
        save_path: If given, the figure is saved there before being shown.
        metric_names: Metric keys to plot, in x-axis order. Defaults to
            ``['accuracy', 'precision', 'recall', 'f1']``.

    Raises:
        ValueError: If ``models_results`` is empty, or if its length differs
            from ``model_names``. Previously the first case divided by zero
            and the second silently dropped models from the chart.
        KeyError: If a result dict lacks one of the requested metrics.
    """
    if metric_names is None:
        metrics = ['accuracy', 'precision', 'recall', 'f1']
    else:
        metrics = list(metric_names)

    results = list(models_results)
    names = list(model_names)
    if not results:
        raise ValueError("models_results must contain at least one model")
    if len(results) != len(names):
        raise ValueError(
            f"got {len(results)} result dicts but {len(names)} model names"
        )

    # Rows = models, columns = metrics.
    values = np.array([[result[metric] for metric in metrics] for result in results])

    plt.figure(figsize=(10, 6))
    x = np.arange(len(metrics))
    # All bars of one metric group share 80% of the unit spacing.
    width = 0.8 / len(results)

    for i, (name, vals) in enumerate(zip(names, values)):
        plt.bar(x + i * width, vals, width, label=name)

    plt.xlabel('评估指标')
    plt.ylabel('分数')
    plt.title('不同模型性能比较')
    # Centre each tick under its group of bars (bars are centre-aligned).
    plt.xticks(x + width * (len(results) - 1) / 2, metrics)
    plt.ylim(0, 1.0)
    plt.legend()
    plt.grid(True, axis='y')

    if save_path:
        plt.savefig(save_path)

    plt.show()
|