# leukemia/evaluation/evaluate.py

import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn.manifold import TSNE

from training.utils import compute_metrics, plot_confusion_matrix


def evaluate_model(model, data_loader, device, class_names=None):
    """Evaluate model performance on a data loader and report classification metrics."""
    model.eval()
    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in data_loader:
            # Move both image modalities and the labels to the target device
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].to(device)

            # Forward pass through the dual-input model
            outputs = model(diff_imgs, wnb_imgs)

            # Class probabilities and hard predictions
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            # Collect results on the CPU
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Compute evaluation metrics
    metrics = compute_metrics(all_labels, all_preds, all_probs)

    # Plot the confusion matrix
    if class_names is None:
        class_names = ['Normal', 'Abnormal']
    plot_confusion_matrix(all_labels, all_preds, class_names)

    # Print the evaluation results
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 score: {metrics['f1']:.4f}")

    return metrics, all_labels, all_preds, np.array(all_probs)


def plot_roc_curve(all_labels, all_probs, save_path=None):
    """Plot the ROC curve for the binary classification task."""
    # For the binary case, use the probability of the positive (abnormal) class
    pos_probs = all_probs[:, 1]

    # Compute the ROC curve and its AUC
    fpr, tpr, thresholds = roc_curve(all_labels, pos_probs)
    roc_auc = auc(fpr, tpr)

    # Plot the ROC curve
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('Receiver operating characteristic (ROC) curve')
    plt.legend(loc='lower right')
    plt.grid(True)

    if save_path:
        plt.savefig(save_path)
    plt.show()

    return roc_auc


def extract_and_visualize_features(model, data_loader, device, save_path=None):
    """Extract per-modality features and visualize them with t-SNE."""
    model.eval()
    features_dict = {
        'diff': [],
        'wnb': []
    }
    all_labels = []

    with torch.no_grad():
        for batch in data_loader:
            # Move both image modalities to the target device
            diff_imgs = batch['diff_img'].to(device)
            wnb_imgs = batch['wnb_img'].to(device)
            labels = batch['label'].cpu().numpy()

            # Extract per-modality features
            batch_features = model.extract_features(diff_imgs, wnb_imgs)

            # Collect features and labels
            for modality in features_dict:
                features_dict[modality].extend(batch_features[modality].cpu().numpy())
            all_labels.extend(labels)

    # Convert to a NumPy array
    all_labels = np.array(all_labels)

    # Visualize the features of each modality
    plt.figure(figsize=(15, 5))
    for i, (modality, features) in enumerate(features_dict.items()):
        features = np.array(features)

        # Reduce to two dimensions with t-SNE
        tsne = TSNE(n_components=2, random_state=42)
        features_tsne = tsne.fit_transform(features)

        # Plot the t-SNE embedding, colored by class
        plt.subplot(1, 3, i + 1)
        for label in np.unique(all_labels):
            idx = all_labels == label
            plt.scatter(features_tsne[idx, 0], features_tsne[idx, 1],
                        label=f"Class {label}", alpha=0.7)
        plt.title(f't-SNE visualization of {modality} features')
        plt.xlabel('t-SNE dimension 1')
        plt.ylabel('t-SNE dimension 2')
        plt.legend()
        plt.grid(True)

    # Concatenate the two modalities and visualize the fused features
    combined_features = np.concatenate(
        [np.array(features_dict['diff']), np.array(features_dict['wnb'])], axis=1)
    tsne = TSNE(n_components=2, random_state=42)
    combined_tsne = tsne.fit_transform(combined_features)

    plt.subplot(1, 3, 3)
    for label in np.unique(all_labels):
        idx = all_labels == label
        plt.scatter(combined_tsne[idx, 0], combined_tsne[idx, 1],
                    label=f"Class {label}", alpha=0.7)
    plt.title('t-SNE visualization of fused features')
    plt.xlabel('t-SNE dimension 1')
    plt.ylabel('t-SNE dimension 2')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


def compare_models(models_results, model_names, save_path=None):
    """Compare the performance of different models with a grouped bar chart."""
    metrics = ['accuracy', 'precision', 'recall', 'f1']
    values = []
    for result in models_results:
        values.append([result[metric] for metric in metrics])
    values = np.array(values)

    # Grouped bar chart, one group per metric and one bar per model
    plt.figure(figsize=(10, 6))
    x = np.arange(len(metrics))
    width = 0.8 / len(models_results)

    for i, (name, vals) in enumerate(zip(model_names, values)):
        plt.bar(x + i * width, vals, width, label=name)

    plt.xlabel('Evaluation metric')
    plt.ylabel('Score')
    plt.title('Performance comparison of different models')
    plt.xticks(x + width * (len(models_results) - 1) / 2, metrics)
    plt.ylim(0, 1.0)
    plt.legend()
    plt.grid(True, axis='y')

    if save_path:
        plt.savefig(save_path)
    plt.show()
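

# Minimal smoke-test sketch for the plotting helpers above, using small synthetic
# arrays rather than real model outputs; the numbers below are illustrative only.
# evaluate_model() and extract_and_visualize_features() are not exercised here
# because they require a trained dual-input model and a DataLoader whose batches
# expose 'diff_img', 'wnb_img', and 'label' keys, which this module does not build.
if __name__ == '__main__':
    rng = np.random.default_rng(0)

    # Synthetic binary labels and softmax-like probabilities for the ROC helper
    demo_labels = rng.integers(0, 2, size=200)
    pos_scores = np.clip(demo_labels * 0.6 + rng.normal(0.2, 0.25, size=200), 0.0, 1.0)
    demo_probs = np.stack([1.0 - pos_scores, pos_scores], axis=1)
    plot_roc_curve(demo_labels, demo_probs)

    # Synthetic metric dictionaries for the model-comparison helper
    results_a = {'accuracy': 0.91, 'precision': 0.89, 'recall': 0.90, 'f1': 0.895}
    results_b = {'accuracy': 0.87, 'precision': 0.85, 'recall': 0.88, 'f1': 0.865}
    compare_models([results_a, results_b], ['Model A', 'Model B'])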