添加了loss函数曲线显示
This commit is contained in:
		
							parent
							
								
									b8468e955b
								
							
						
					
					
						commit
						40bdf03e5a
					
				@ -24,7 +24,7 @@ class ChestXrayDataset(Dataset):
 | 
			
		||||
 | 
			
		||||
        return image, label
 | 
			
		||||
 | 
			
		||||
def prepare_data(data_dir, batch_size=32):
 | 
			
		||||
def prepare_data(data_dir, batch_size=32):  # 调低 batch_size
 | 
			
		||||
    # 获取所有图片文件路径
 | 
			
		||||
    normal_dir = os.path.join(data_dir, 'normal')
 | 
			
		||||
    pneumonia_dir = os.path.join(data_dir, 'pneumonia')
 | 
			
		||||
@ -67,7 +67,7 @@ def prepare_data(data_dir, batch_size=32):
 | 
			
		||||
    # 创建数据加载器
 | 
			
		||||
    train_loader = DataLoader(
 | 
			
		||||
        train_dataset, 
 | 
			
		||||
        batch_size=batch_size, 
 | 
			
		||||
        batch_size=batch_size,  # 使用调低后的 batch_size
 | 
			
		||||
        shuffle=True, 
 | 
			
		||||
        num_workers=4,
 | 
			
		||||
        pin_memory=True
 | 
			
		||||
@ -75,10 +75,10 @@ def prepare_data(data_dir, batch_size=32):
 | 
			
		||||
    
 | 
			
		||||
    test_loader = DataLoader(
 | 
			
		||||
        test_dataset, 
 | 
			
		||||
        batch_size=batch_size, 
 | 
			
		||||
        batch_size=batch_size,  # 使用调低后的 batch_size
 | 
			
		||||
        shuffle=False,
 | 
			
		||||
        num_workers=4,
 | 
			
		||||
        pin_memory=True
 | 
			
		||||
    )
 | 
			
		||||
    
 | 
			
		||||
    return train_loader, test_loader
 | 
			
		||||
    return train_loader, test_loader
 | 
			
		||||
							
								
								
									
										15
									
								
								src/train.py
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								src/train.py
									
									
									
									
									
								
							@ -5,6 +5,7 @@ from model import MultiResRibNet
 | 
			
		||||
from dataset import prepare_data
 | 
			
		||||
import os
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
 | 
			
		||||
def train(model, train_loader, test_loader, num_epochs=50):
 | 
			
		||||
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 | 
			
		||||
@ -15,6 +16,8 @@ def train(model, train_loader, test_loader, num_epochs=50):
 | 
			
		||||
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
 | 
			
		||||
    
 | 
			
		||||
    best_acc = 0.0
 | 
			
		||||
    train_losses = []
 | 
			
		||||
    test_losses = []
 | 
			
		||||
    
 | 
			
		||||
    for epoch in range(num_epochs):
 | 
			
		||||
        # 训练阶段
 | 
			
		||||
@ -45,6 +48,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
 | 
			
		||||
        
 | 
			
		||||
        train_loss = running_loss / len(train_loader)
 | 
			
		||||
        train_acc = 100. * correct / total
 | 
			
		||||
        train_losses.append(train_loss)
 | 
			
		||||
        
 | 
			
		||||
        # 测试阶段
 | 
			
		||||
        model.eval()
 | 
			
		||||
@ -65,6 +69,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
 | 
			
		||||
        
 | 
			
		||||
        test_loss = test_loss / len(test_loader)
 | 
			
		||||
        test_acc = 100. * correct / total
 | 
			
		||||
        test_losses.append(test_loss)
 | 
			
		||||
        
 | 
			
		||||
        scheduler.step(test_loss)
 | 
			
		||||
        
 | 
			
		||||
@ -76,6 +81,16 @@ def train(model, train_loader, test_loader, num_epochs=50):
 | 
			
		||||
        if test_acc > best_acc:
 | 
			
		||||
            best_acc = test_acc
 | 
			
		||||
            torch.save(model.state_dict(), 'best_model.pth')
 | 
			
		||||
    
 | 
			
		||||
    # 绘制损失图
 | 
			
		||||
    plt.figure(figsize=(10, 5))
 | 
			
		||||
    plt.plot(train_losses, label='Train Loss')
 | 
			
		||||
    plt.plot(test_losses, label='Test Loss')
 | 
			
		||||
    plt.xlabel('Epoch')
 | 
			
		||||
    plt.ylabel('Loss')
 | 
			
		||||
    plt.title('Training and Test Loss')
 | 
			
		||||
    plt.legend()
 | 
			
		||||
    plt.show()
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    # 数据路径
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user