import os
import glob
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split


class LeukemiaDataset(Dataset):
    def __init__(self, sample_paths, labels, transform=None):
        """
        Initialize the leukemia dataset.

        Args:
            sample_paths: list of sample paths; each element is a
                (diff_path, wnb_path) tuple
            labels: corresponding labels (0: normal, 1: abnormal)
            transform: image transform applied to both scattergrams
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        diff_path, wnb_path = self.sample_paths[idx]

        # Extract the sample ID from the file name
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')

        # Load the DIFF scattergram
        diff_img = Image.open(diff_path).convert('RGB')

        # Load the WNB scattergram
        wnb_img = Image.open(wnb_path).convert('RGB')

        if self.transform:
            diff_img = self.transform(diff_img)
            wnb_img = self.transform(wnb_img)

        label = self.labels[idx]

        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(label, dtype=torch.long)
        }

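# Illustrative usage sketch for LeukemiaDataset (kept as comments so nothing
# runs on import). The file names below are hypothetical placeholders; they
# only assume the paired "<id>Diff.png" / "<id>Wnb.png" naming convention
# this module relies on.
#
#   pairs = [("data/000/0001Diff.png", "data/000/0001Wnb.png")]
#   dataset = LeukemiaDataset(pairs, labels=[0],
#                             transform=transforms.ToTensor())
#   sample = dataset[0]
#   # sample['diff_img'] and sample['wnb_img'] are CxHxW float tensors;
#   # sample['label'] is a scalar LongTensor, sample['id'] the "0001" stem.

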
def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
              max_samples_per_class=500, normal_class="000", abnormal_classes=None):
    """
    Load and preprocess the data.

    Args:
        data_root: data root directory containing the subfolders 000, 001, etc.
        img_size: target image size (images are resized to img_size x img_size)
        batch_size: batch size
        num_workers: number of worker processes for data loading
        train_ratio: fraction of samples used for training
        max_samples_per_class: maximum number of samples read per class
        normal_class: folder name of the normal class
        abnormal_classes: list of folder names of the abnormal classes

    Returns:
        train_loader: training data loader
        val_loader: validation data loader
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]

    # Image transforms; the mean/std values are the standard ImageNet
    # statistics used by torchvision's pretrained models
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Collect normal samples
    normal_samples = []
    normal_folder = os.path.join(data_root, normal_class)
    diff_files = glob.glob(os.path.join(normal_folder, "*Diff.png"))

    # Cap the number of normal samples
    if len(diff_files) > max_samples_per_class:
        random.shuffle(diff_files)
        diff_files = diff_files[:max_samples_per_class]

    for diff_file in diff_files:
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(normal_folder, f"{sample_id}Wnb.png")

        # Keep the pair only if the matching WNB file exists
        if os.path.exists(wnb_file):
            normal_samples.append((diff_file, wnb_file))

    # Collect abnormal samples
    abnormal_samples = []

    # Budget per abnormal class: roughly max_samples_per_class abnormal
    # samples in total, split evenly across the abnormal classes
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)

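    # Worked example (illustrative): with the defaults max_samples_per_class=500
    # and 7 abnormal classes, each class contributes at most 500 // 7 = 71
    # samples, i.e. at most 497 abnormal samples in total, since floor
    # division drops the remainder.
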
    for abnormal_class in abnormal_classes:
        abnormal_folder = os.path.join(data_root, abnormal_class)
        diff_files = glob.glob(os.path.join(abnormal_folder, "*Diff.png"))

        # Cap the number of samples per abnormal class
        if len(diff_files) > samples_per_abnormal_class:
            random.shuffle(diff_files)
            diff_files = diff_files[:samples_per_abnormal_class]

        for diff_file in diff_files:
            sample_id = os.path.basename(diff_file).replace('Diff.png', '')
            wnb_file = os.path.join(abnormal_folder, f"{sample_id}Wnb.png")

            # Keep the pair only if the matching WNB file exists
            if os.path.exists(wnb_file):
                abnormal_samples.append((diff_file, wnb_file))

    # Assemble the dataset
    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)

    print(f"Normal samples collected: {len(normal_samples)}")
    print(f"Abnormal samples collected: {len(abnormal_samples)}")
    print(f"Total samples: {len(all_samples)}")

    # Split into training and validation sets
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1 - train_ratio),
        stratify=all_labels,  # keep the class ratio consistent across both sets
        random_state=42
    )

    # Build the training and validation subsets
    train_samples = [all_samples[i] for i in train_indices]
    train_labels = [all_labels[i] for i in train_indices]

    val_samples = [all_samples[i] for i in val_indices]
    val_labels = [all_labels[i] for i in val_indices]

    # Create the datasets
    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)

    # Create the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

print(f"训练集样本数: {len(train_dataset)}, 验证集样本数: {len(val_dataset)}")
|
||
|
||
return train_loader, val_loader |
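

# Minimal end-to-end sketch, assuming a directory layout like
#   <data_root>/000/*Diff.png + *Wnb.png          (normal)
#   <data_root>/001/..., <data_root>/010/..., etc. (abnormal)
# "./data" below is a hypothetical placeholder path, not part of this module.
if __name__ == "__main__":
    train_loader, val_loader = load_data("./data", img_size=224, batch_size=32)
    batch = next(iter(train_loader))
    # batch['diff_img'] and batch['wnb_img'] have shape (B, 3, 224, 224);
    # batch['label'] has shape (B,); batch['id'] is a list of sample IDs.
    print(batch['diff_img'].shape, batch['wnb_img'].shape, batch['label'].shape)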