36 lines
909 B
Python
36 lines
909 B
Python
|
|
import os
|
|||
|
|
import torch
|
|||
|
|
|
|||
|
|
# 路径配置
|
|||
|
|
DATA_ROOT = "./data" # 数据根目录,包含000,001等子文件夹
|
|||
|
|
OUTPUT_DIR = "./output"
|
|||
|
|
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
|
|||
|
|
|
|||
|
|
# 确保目录存在
|
|||
|
|
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
|||
|
|
|
|||
|
|
# 数据配置
|
|||
|
|
IMG_SIZE = 224 # 调整图像大小
|
|||
|
|
BATCH_SIZE = 32
|
|||
|
|
NUM_WORKERS = 4
|
|||
|
|
TRAIN_RATIO = 0.8
|
|||
|
|
VAL_RATIO = 0.2
|
|||
|
|
|
|||
|
|
# 样本数量控制
|
|||
|
|
MAX_SAMPLES_PER_CLASS = 1000 # 每个类别最多读取的样本数
|
|||
|
|
NORMAL_CLASS = "000" # 正常类别的文件夹名
|
|||
|
|
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"] # 异常类别的文件夹名
|
|||
|
|
|
|||
|
|
# 模型配置
|
|||
|
|
HIDDEN_DIM = 768
|
|||
|
|
NUM_HEADS = 12
|
|||
|
|
NUM_LAYERS = 6
|
|||
|
|
DROPOUT = 0.1
|
|||
|
|
NUM_CLASSES = 2 # 二分类:正常 vs 异常
|
|||
|
|
|
|||
|
|
# 训练配置
|
|||
|
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|||
|
|
LEARNING_RATE = 1e-4
|
|||
|
|
WEIGHT_DECAY = 1e-5
|
|||
|
|
NUM_EPOCHS = 50
|
|||
|
|
EARLY_STOPPING_PATIENCE = 10
|