leukemia/data_preprocessing/data_loader.py

import os
import glob
import random
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split

class LeukemiaDataset(Dataset):
    def __init__(self, sample_paths, labels, transform=None):
        """
        Initialize the leukemia dataset.

        Args:
            sample_paths: list of (diff_path, wnb_path) tuples, one per sample
            labels: corresponding label list (0: normal, 1: abnormal)
            transform: image transform to apply
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        diff_path, wnb_path = self.sample_paths[idx]
        # Extract the sample ID from the file path
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')
        # Load the DIFF scatter plot
        diff_img = Image.open(diff_path).convert('RGB')
        # Load the WNB scatter plot
        wnb_img = Image.open(wnb_path).convert('RGB')
        if self.transform:
            diff_img = self.transform(diff_img)
            wnb_img = self.transform(wnb_img)
        label = self.labels[idx]
        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(label, dtype=torch.long)
        }

def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
              max_samples_per_class=500, normal_class="000", abnormal_classes=None):
    """
    Load and preprocess the data.

    Args:
        data_root: data root directory containing the class subfolders (000, 001, ...)
        img_size: target image size
        batch_size: batch size
        num_workers: number of worker processes for data loading
        train_ratio: fraction of samples used for training
        max_samples_per_class: maximum number of samples to read per class
        normal_class: folder name of the normal class
        abnormal_classes: list of folder names of the abnormal classes

    Returns:
        train_loader: training data loader
        val_loader: validation data loader
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]

    # Image transforms
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Collect normal samples
    normal_samples = []
    normal_folder = os.path.join(data_root, normal_class)
    diff_files = glob.glob(os.path.join(normal_folder, "*Diff.png"))
    # Limit the number of normal samples
    if len(diff_files) > max_samples_per_class:
        random.shuffle(diff_files)
        diff_files = diff_files[:max_samples_per_class]
    for diff_file in diff_files:
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(normal_folder, f"{sample_id}Wnb.png")
        # Keep the sample only if the matching WNB file exists
        if os.path.exists(wnb_file):
            normal_samples.append((diff_file, wnb_file))

    # Collect abnormal samples
    abnormal_samples = []
    # Number of samples allocated to each abnormal class:
    # with max_samples_per_class=500, at most 500 abnormal samples are taken in total,
    # split evenly across the abnormal classes.
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)
    for abnormal_class in abnormal_classes:
        abnormal_folder = os.path.join(data_root, abnormal_class)
        diff_files = glob.glob(os.path.join(abnormal_folder, "*Diff.png"))
        # Limit the number of samples per abnormal class
        if len(diff_files) > samples_per_abnormal_class:
            random.shuffle(diff_files)
            diff_files = diff_files[:samples_per_abnormal_class]
        for diff_file in diff_files:
            sample_id = os.path.basename(diff_file).replace('Diff.png', '')
            wnb_file = os.path.join(abnormal_folder, f"{sample_id}Wnb.png")
            # Keep the sample only if the matching WNB file exists
            if os.path.exists(wnb_file):
                abnormal_samples.append((diff_file, wnb_file))

    # Assemble the full dataset
    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)
    print(f"Normal samples collected: {len(normal_samples)}")
    print(f"Abnormal samples collected: {len(abnormal_samples)}")
    print(f"Total samples: {len(all_samples)}")
    # Split into training and validation sets
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1 - train_ratio),
        stratify=all_labels,  # keep the class ratio consistent across train and validation
        random_state=42
    )

    # Build the training and validation subsets
    train_samples = [all_samples[i] for i in train_indices]
    train_labels = [all_labels[i] for i in train_indices]
    val_samples = [all_samples[i] for i in val_indices]
    val_labels = [all_labels[i] for i in val_indices]

    # Create datasets
    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )
    print(f"Training samples: {len(train_dataset)}, validation samples: {len(val_dataset)}")
    return train_loader, val_loader
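

# A minimal usage sketch, not part of the original pipeline: it shows how load_data
# is intended to be called and what each batch dict from LeukemiaDataset looks like.
# The data_root path below is hypothetical; the real directory is expected to contain
# the class subfolders 000, 001, ... with paired "<id>Diff.png" / "<id>Wnb.png" images.
if __name__ == "__main__":
    train_loader, val_loader = load_data(
        data_root="./data",  # hypothetical path, replace with the real dataset root
        img_size=224,
        batch_size=32,
        num_workers=4,
        train_ratio=0.8,
        max_samples_per_class=500,
    )
    # Inspect one batch: expected shapes are [B, 3, 224, 224] for both scatter plots
    # and [B] for the binary labels.
    batch = next(iter(train_loader))
    print(batch['diff_img'].shape, batch['wnb_img'].shape, batch['label'].shape)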