import os
import glob
import random

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split


class LeukemiaDataset(Dataset):
    def __init__(self, sample_paths, labels, transform=None):
        """
        Initialize the leukemia dataset.

        Args:
            sample_paths: list of samples, each a (diff_path, wnb_path) tuple
            labels: corresponding labels (0: normal, 1: abnormal)
            transform: image transform pipeline
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        diff_path, wnb_path = self.sample_paths[idx]

        # Extract the sample ID from the file path
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')

        # Load the DIFF scatter plot
        diff_img = Image.open(diff_path).convert('RGB')
        # Load the WNB scatter plot
        wnb_img = Image.open(wnb_path).convert('RGB')

        if self.transform:
            diff_img = self.transform(diff_img)
            wnb_img = self.transform(wnb_img)

        label = self.labels[idx]

        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(label, dtype=torch.long)
        }


def load_data(data_root, img_size=224, batch_size=32, num_workers=4,
              train_ratio=0.8, max_samples_per_class=500,
              normal_class="000", abnormal_classes=None):
    """
    Load and preprocess the data.

    Args:
        data_root: data root directory containing subfolders such as 000, 001, ...
        img_size: target image size
        batch_size: batch size
        num_workers: number of worker processes for data loading
        train_ratio: fraction of samples used for training
        max_samples_per_class: maximum number of samples read per class
        normal_class: folder name of the normal class
        abnormal_classes: list of folder names of the abnormal classes

    Returns:
        train_loader: training data loader
        val_loader: validation data loader
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]

    # Image transforms (ImageNet normalization statistics)
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Collect normal samples
    normal_samples = []
    normal_folder = os.path.join(data_root, normal_class)
    diff_files = glob.glob(os.path.join(normal_folder, "*Diff.png"))

    # Cap the number of normal samples
    if len(diff_files) > max_samples_per_class:
        random.shuffle(diff_files)
        diff_files = diff_files[:max_samples_per_class]

    for diff_file in diff_files:
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(normal_folder, f"{sample_id}Wnb.png")

        # Only keep the sample if the matching WNB file exists
        if os.path.exists(wnb_file):
            normal_samples.append((diff_file, wnb_file))

    # Collect abnormal samples
    abnormal_samples = []
    # Number of samples allotted to each abnormal class.
    # With max_samples_per_class=500, at most 500 abnormal samples are drawn
    # in total, split evenly across the abnormal classes.
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)

    for abnormal_class in abnormal_classes:
        abnormal_folder = os.path.join(data_root, abnormal_class)
        diff_files = glob.glob(os.path.join(abnormal_folder, "*Diff.png"))

        # Cap the number of samples per abnormal class
        if len(diff_files) > samples_per_abnormal_class:
            random.shuffle(diff_files)
            diff_files = diff_files[:samples_per_abnormal_class]

        for diff_file in diff_files:
            sample_id = os.path.basename(diff_file).replace('Diff.png', '')
            wnb_file = os.path.join(abnormal_folder, f"{sample_id}Wnb.png")

            # Only keep the sample if the matching WNB file exists
            if os.path.exists(wnb_file):
                abnormal_samples.append((diff_file, wnb_file))

    # Assemble the full dataset
    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)

    print(f"Normal samples collected: {len(normal_samples)}")
    print(f"Abnormal samples collected: {len(abnormal_samples)}")
    print(f"Total samples: {len(all_samples)}")

    # Split into training and validation sets
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1 - train_ratio),
        stratify=all_labels,  # keep class ratios consistent across the split
        random_state=42
    )

    # Build the training and validation subsets
    train_samples = [all_samples[i] for i in train_indices]
    train_labels = [all_labels[i] for i in train_indices]
    val_samples = [all_samples[i] for i in val_indices]
    val_labels = [all_labels[i] for i in val_indices]

    # Create datasets
    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    print(f"Training samples: {len(train_dataset)}, validation samples: {len(val_dataset)}")

    return train_loader, val_loader
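

# Usage sketch: a minimal way this loader might be invoked. The data_root
# value "./data" and the hyperparameters below are placeholder assumptions
# for illustration, not values taken from the original project.
if __name__ == "__main__":
    train_loader, val_loader = load_data(
        data_root="./data",  # assumed layout: ./data/000, ./data/001, ...
        img_size=224,
        batch_size=32,
        num_workers=0,       # single-process loading for a quick local check
    )

    # Inspect one training batch to confirm the dict structure and shapes
    batch = next(iter(train_loader))
    print(batch['id'][:4])           # list of sample ID strings
    print(batch['diff_img'].shape)   # expected: [batch_size, 3, 224, 224]
    print(batch['wnb_img'].shape)    # expected: [batch_size, 3, 224, 224]
    print(batch['label'].shape)      # expected: [batch_size]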