171 lines
5.8 KiB
Python
171 lines
5.8 KiB
Python
|
|
import os
|
|||
|
|
import glob
|
|||
|
|
import random
|
|||
|
|
import numpy as np
|
|||
|
|
from PIL import Image
|
|||
|
|
import torch
|
|||
|
|
from torch.utils.data import Dataset, DataLoader
|
|||
|
|
from torchvision import transforms
|
|||
|
|
from sklearn.model_selection import train_test_split
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LeukemiaDataset(Dataset):
    """Paired scatter-plot dataset for leukemia classification.

    Each sample is a (DIFF image, WNB image) pair plus a binary label
    (0 = normal, 1 = abnormal).
    """

    def __init__(self, sample_paths, labels, transform=None):
        """Store the sample list.

        Args:
            sample_paths: List of (diff_path, wnb_path) tuples.
            labels: Matching label list (0: normal, 1: abnormal).
            transform: Optional image transform applied to both images.
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        diff_path, wnb_path = self.sample_paths[idx]

        # The sample ID is the diff filename with its 'Diff.png' suffix removed.
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')

        # Load the DIFF and WNB scatter plots as RGB images.
        images = [Image.open(path).convert('RGB') for path in (diff_path, wnb_path)]
        if self.transform:
            images = [self.transform(img) for img in images]
        diff_img, wnb_img = images

        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
              max_samples_per_class=500, normal_class="000", abnormal_classes=None):
    """Load paired scatter-plot images and build train/val DataLoaders.

    Args:
        data_root: Root directory containing the class subfolders ("000", "001", ...).
        img_size: Side length that every image is resized to.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes used by the DataLoaders.
        train_ratio: Fraction of samples assigned to the training set.
        max_samples_per_class: Cap on samples read per class; the same total
            budget is split evenly across the abnormal subfolders.
        normal_class: Folder name holding normal samples (label 0).
        abnormal_classes: Folder names holding abnormal samples (label 1);
            defaults to the seven non-"000" binary codes.

    Returns:
        (train_loader, val_loader): DataLoaders over LeukemiaDataset instances.

    Raises:
        ValueError: If ``abnormal_classes`` is an empty list.
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]
    if not abnormal_classes:
        # Guard the integer division below against ZeroDivisionError.
        raise ValueError("abnormal_classes must contain at least one folder name")

    # Image preprocessing; mean/std are the ImageNet statistics commonly used
    # with pretrained backbones.
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Collect normal samples: up to max_samples_per_class pairs from one folder.
    normal_samples = _collect_pairs(os.path.join(data_root, normal_class),
                                    max_samples_per_class)

    # Collect abnormal samples: the same total budget, split evenly so each
    # abnormal subfolder contributes at most max_samples_per_class // N pairs.
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)
    abnormal_samples = []
    for abnormal_class in abnormal_classes:
        abnormal_samples.extend(
            _collect_pairs(os.path.join(data_root, abnormal_class),
                           samples_per_abnormal_class))

    # Assemble the full sample/label lists (0: normal, 1: abnormal).
    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)

    print(f"收集到的正常样本数: {len(normal_samples)}")
    print(f"收集到的异常样本数: {len(abnormal_samples)}")
    print(f"总样本数: {len(all_samples)}")

    # Stratified split keeps the class ratio identical in train and val.
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1 - train_ratio),
        stratify=all_labels,
        random_state=42
    )

    train_samples = [all_samples[i] for i in train_indices]
    train_labels = [all_labels[i] for i in train_indices]

    val_samples = [all_samples[i] for i in val_indices]
    val_labels = [all_labels[i] for i in val_indices]

    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    print(f"训练集样本数: {len(train_dataset)}, 验证集样本数: {len(val_dataset)}")

    return train_loader, val_loader


def _collect_pairs(folder, limit):
    """Return up to ``limit`` (diff_path, wnb_path) pairs found in ``folder``.

    Diff images are discovered by globbing ``*Diff.png``; a pair is kept only
    when the matching ``*Wnb.png`` file also exists. When more than ``limit``
    diff files are found, a random subset is taken (uses the global ``random``
    state, so results vary between runs unless the caller seeds it).
    """
    diff_files = glob.glob(os.path.join(folder, "*Diff.png"))

    if len(diff_files) > limit:
        random.shuffle(diff_files)
        diff_files = diff_files[:limit]

    pairs = []
    for diff_file in diff_files:
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(folder, f"{sample_id}Wnb.png")
        if os.path.exists(wnb_file):
            pairs.append((diff_file, wnb_file))
    return pairs
|