36 lines
		
	
	
		
			909 B
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			36 lines
		
	
	
		
			909 B
		
	
	
	
		
			Python
		
	
	
	
	
	
| 
								 | 
							
								import os
							 | 
						|||
| 
								 | 
							
								import torch
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								# 路径配置
							 | 
						|||
| 
								 | 
							
								DATA_ROOT = "./data"  # 数据根目录,包含000,001等子文件夹
							 | 
						|||
| 
								 | 
							
								OUTPUT_DIR = "./output"
							 | 
						|||
| 
								 | 
							
								MODEL_SAVE_DIR = os.path.join(OUTPUT_DIR, "models")
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								# 确保目录存在
							 | 
						|||
| 
								 | 
							
								os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								# 数据配置
							 | 
						|||
| 
								 | 
							
								IMG_SIZE = 224  # 调整图像大小
							 | 
						|||
| 
								 | 
							
								BATCH_SIZE = 32
							 | 
						|||
| 
								 | 
							
								NUM_WORKERS = 4
							 | 
						|||
| 
								 | 
							
								TRAIN_RATIO = 0.8
							 | 
						|||
| 
								 | 
							
								VAL_RATIO = 0.2
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								# 样本数量控制
							 | 
						|||
| 
								 | 
							
								MAX_SAMPLES_PER_CLASS = 1000  # 每个类别最多读取的样本数
							 | 
						|||
| 
								 | 
							
								NORMAL_CLASS = "000"  # 正常类别的文件夹名
							 | 
						|||
| 
								 | 
							
								ABNORMAL_CLASSES = ["001", "010", "011", "100", "101", "110", "111"]  # 异常类别的文件夹名
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								# 模型配置
							 | 
						|||
| 
								 | 
							
								HIDDEN_DIM = 768
							 | 
						|||
| 
								 | 
							
								NUM_HEADS = 12
							 | 
						|||
| 
								 | 
							
								NUM_LAYERS = 6
							 | 
						|||
| 
								 | 
							
								DROPOUT = 0.1
							 | 
						|||
| 
								 | 
							
								NUM_CLASSES = 2  # 二分类:正常 vs 异常
							 | 
						|||
| 
								 | 
							
								
							 | 
						|||
| 
								 | 
							
								# 训练配置
							 | 
						|||
| 
								 | 
							
								DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
							 | 
						|||
| 
								 | 
							
								LEARNING_RATE = 1e-4
							 | 
						|||
| 
								 | 
							
								WEIGHT_DECAY = 1e-5
							 | 
						|||
| 
								 | 
							
								NUM_EPOCHS = 50
							 | 
						|||
| 
								 | 
							
								EARLY_STOPPING_PATIENCE = 10
							 |