36 lines
909 B
Python
36 lines
909 B
Python
import os
|
||
import torch
|
||
|
||
# 路径配置
|
||
DATA_ROOT = "./data" # 数据根目录,包含000,001等子文件夹
|
||
OUTPUT_DIR = "./output"
|
||
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
|
||
|
||
# 确保目录存在
|
||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||
|
||
# 数据配置
|
||
IMG_SIZE = 224 # 调整图像大小
|
||
BATCH_SIZE = 32
|
||
NUM_WORKERS = 4
|
||
TRAIN_RATIO = 0.8
|
||
VAL_RATIO = 0.2
|
||
|
||
# 样本数量控制
|
||
MAX_SAMPLES_PER_CLASS = 1000 # 每个类别最多读取的样本数
|
||
NORMAL_CLASS = "000" # 正常类别的文件夹名
|
||
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"] # 异常类别的文件夹名
|
||
|
||
# 模型配置
|
||
HIDDEN_DIM = 768
|
||
NUM_HEADS = 12
|
||
NUM_LAYERS = 6
|
||
DROPOUT = 0.1
|
||
NUM_CLASSES = 2 # 二分类:正常 vs 异常
|
||
|
||
# 训练配置
|
||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
LEARNING_RATE = 1e-4
|
||
WEIGHT_DECAY = 1e-5
|
||
NUM_EPOCHS = 50
|
||
EARLY_STOPPING_PATIENCE = 10 |