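"""Data loading utilities for leukemia scatter-plot classification.

Builds a binary (normal vs. abnormal) image dataset from paired DIFF and WNB
scatter plots stored in per-class subfolders (000 = normal, 001-111 = abnormal)
and returns stratified training and validation DataLoaders.
"""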
import os
import glob
import random

import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split


class LeukemiaDataset(Dataset):
    def __init__(self, sample_paths, labels, transform=None):
        """
        Initialize the leukemia dataset.

        Args:
            sample_paths: List of sample paths; each element is a
                (diff_path, wnb_path) tuple.
            labels: Corresponding list of labels (0: normal, 1: abnormal).
            transform: Image transform to apply.
        """
        self.sample_paths = sample_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.sample_paths)

    def __getitem__(self, idx):
        diff_path, wnb_path = self.sample_paths[idx]

        # Extract the sample ID from the file name
        sample_id = os.path.basename(diff_path).replace('Diff.png', '')

        # Load the DIFF scatter plot
        diff_img = Image.open(diff_path).convert('RGB')

        # Load the WNB scatter plot
        wnb_img = Image.open(wnb_path).convert('RGB')

        if self.transform:
            diff_img = self.transform(diff_img)
            wnb_img = self.transform(wnb_img)

        label = self.labels[idx]

        return {
            'id': sample_id,
            'diff_img': diff_img,
            'wnb_img': wnb_img,
            'label': torch.tensor(label, dtype=torch.long)
        }


def load_data(data_root, img_size=224, batch_size=32, num_workers=4, train_ratio=0.8,
              max_samples_per_class=500, normal_class="000", abnormal_classes=None):
    """
    Load and preprocess the data.

    Args:
        data_root: Root data directory containing class subfolders (000, 001, ...).
        img_size: Side length the images are resized to.
        batch_size: Batch size.
        num_workers: Number of data-loading workers.
        train_ratio: Fraction of samples used for training.
        max_samples_per_class: Maximum number of samples to read per class.
        normal_class: Folder name of the normal class.
        abnormal_classes: List of folder names of the abnormal classes.

    Returns:
        train_loader: Training data loader.
        val_loader: Validation data loader.
    """
    if abnormal_classes is None:
        abnormal_classes = ["001", "010", "011", "100", "101", "110", "111"]

    # Image transforms (resize, to tensor, ImageNet normalization)
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Collect normal samples
    normal_samples = []
    normal_folder = os.path.join(data_root, normal_class)
    diff_files = glob.glob(os.path.join(normal_folder, "*Diff.png"))

    # Cap the number of normal samples
    if len(diff_files) > max_samples_per_class:
        random.shuffle(diff_files)
        diff_files = diff_files[:max_samples_per_class]

    for diff_file in diff_files:
        sample_id = os.path.basename(diff_file).replace('Diff.png', '')
        wnb_file = os.path.join(normal_folder, f"{sample_id}Wnb.png")

        # Keep the sample only if the matching WNB file exists
        if os.path.exists(wnb_file):
            normal_samples.append((diff_file, wnb_file))

    # Collect abnormal samples
    abnormal_samples = []

    # Number of samples drawn from each abnormal class: with
    # max_samples_per_class=500, roughly 500 abnormal samples are taken in
    # total, split evenly across the abnormal classes
    samples_per_abnormal_class = max_samples_per_class // len(abnormal_classes)

    for abnormal_class in abnormal_classes:
        abnormal_folder = os.path.join(data_root, abnormal_class)
        diff_files = glob.glob(os.path.join(abnormal_folder, "*Diff.png"))

        # Cap the number of samples per abnormal class
        if len(diff_files) > samples_per_abnormal_class:
            random.shuffle(diff_files)
            diff_files = diff_files[:samples_per_abnormal_class]

        for diff_file in diff_files:
            sample_id = os.path.basename(diff_file).replace('Diff.png', '')
            wnb_file = os.path.join(abnormal_folder, f"{sample_id}Wnb.png")

            # Keep the sample only if the matching WNB file exists
            if os.path.exists(wnb_file):
                abnormal_samples.append((diff_file, wnb_file))

    # Assemble the full dataset
    all_samples = normal_samples + abnormal_samples
    all_labels = [0] * len(normal_samples) + [1] * len(abnormal_samples)

    print(f"Normal samples collected: {len(normal_samples)}")
    print(f"Abnormal samples collected: {len(abnormal_samples)}")
    print(f"Total samples: {len(all_samples)}")

    # Split into training and validation sets
    indices = np.arange(len(all_samples))
    train_indices, val_indices = train_test_split(
        indices,
        test_size=(1 - train_ratio),
        stratify=all_labels,  # keep class proportions consistent across the splits
        random_state=42
    )

    # Build the training and validation splits
    train_samples = [all_samples[i] for i in train_indices]
    train_labels = [all_labels[i] for i in train_indices]

    val_samples = [all_samples[i] for i in val_indices]
    val_labels = [all_labels[i] for i in val_indices]

    # Create the datasets
    train_dataset = LeukemiaDataset(train_samples, train_labels, transform)
    val_dataset = LeukemiaDataset(val_samples, val_labels, transform)

    # Create the data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )

    print(f"Training samples: {len(train_dataset)}, validation samples: {len(val_dataset)}")

    return train_loader, val_loader
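

if __name__ == "__main__":
    # Minimal usage sketch. The data root "./data" is a placeholder assumption;
    # point it at a directory laid out as <data_root>/<class>/<id>Diff.png with
    # matching <id>Wnb.png files.
    train_loader, val_loader = load_data("./data", batch_size=16, num_workers=0)

    # Each batch is a dict, as produced by LeukemiaDataset.__getitem__.
    batch = next(iter(train_loader))
    print(batch['diff_img'].shape)  # torch.Size([16, 3, 224, 224]) with defaults
    print(batch['wnb_img'].shape)   # torch.Size([16, 3, 224, 224]) with defaults
    print(batch['label'])           # tensor of 0s (normal) and 1s (abnormal)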