leukemia/config.py

36 lines
909 B
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
# 路径配置
DATA_ROOT = "./data" # 数据根目录包含000,001等子文件夹
OUTPUT_DIR = "./output"
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
# 确保目录存在
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# 数据配置
IMG_SIZE = 224 # 调整图像大小
BATCH_SIZE = 32
NUM_WORKERS = 4
TRAIN_RATIO = 0.8
VAL_RATIO = 0.2
# 样本数量控制
MAX_SAMPLES_PER_CLASS = 1000 # 每个类别最多读取的样本数
NORMAL_CLASS = "000" # 正常类别的文件夹名
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"] # 异常类别的文件夹名
# 模型配置
HIDDEN_DIM = 768
NUM_HEADS = 12
NUM_LAYERS = 6
DROPOUT = 0.1
NUM_CLASSES = 2 # 二分类:正常 vs 异常
# 训练配置
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10