36 lines
		
	
	
		
			909 B
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			36 lines
		
	
	
		
			909 B
		
	
	
	
		
			Python
		
	
	
	
	
	
import os
 | 
						||
import torch
 | 
						||
 | 
						||
# 路径配置
 | 
						||
DATA_ROOT = "./data"  # 数据根目录,包含000,001等子文件夹
 | 
						||
OUTPUT_DIR = "./output"
 | 
						||
MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
 | 
						||
 | 
						||
# 确保目录存在
 | 
						||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
 | 
						||
 | 
						||
# 数据配置
 | 
						||
IMG_SIZE = 224  # 调整图像大小
 | 
						||
BATCH_SIZE = 32
 | 
						||
NUM_WORKERS = 4
 | 
						||
TRAIN_RATIO = 0.8
 | 
						||
VAL_RATIO = 0.2
 | 
						||
 | 
						||
# 样本数量控制
 | 
						||
MAX_SAMPLES_PER_CLASS = 1000  # 每个类别最多读取的样本数
 | 
						||
NORMAL_CLASS = "000"  # 正常类别的文件夹名
 | 
						||
ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"]  # 异常类别的文件夹名
 | 
						||
 | 
						||
# 模型配置
 | 
						||
HIDDEN_DIM = 768
 | 
						||
NUM_HEADS = 12
 | 
						||
NUM_LAYERS = 6
 | 
						||
DROPOUT = 0.1
 | 
						||
NUM_CLASSES = 2  # 二分类:正常 vs 异常
 | 
						||
 | 
						||
# 训练配置
 | 
						||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
						||
LEARNING_RATE = 1e-4
 | 
						||
WEIGHT_DECAY = 1e-5
 | 
						||
NUM_EPOCHS = 50
 | 
						||
EARLY_STOPPING_PATIENCE = 10 |