leukemia/config.py

36 lines
909 B
Python
Raw Permalink Normal View History

import os
import torch
# 路径配置
DATA_ROOT = "./data" # 数据根目录包含000,001等子文件夹
OUTPUT_DIR = "./output"
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
# 确保目录存在
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# 数据配置
IMG_SIZE = 224 # 调整图像大小
BATCH_SIZE = 32
NUM_WORKERS = 4
TRAIN_RATIO = 0.8
VAL_RATIO = 0.2
# 样本数量控制
MAX_SAMPLES_PER_CLASS = 1000 # 每个类别最多读取的样本数
NORMAL_CLASS = "000" # 正常类别的文件夹名
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"] # 异常类别的文件夹名
# 模型配置
HIDDEN_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 6
DROPOUT = 0.1
NUM_CLASSES = 2 # 二分类:正常 vs 异常
# 训练配置
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10