From 40bdf03e5a925355f96069c38478dbd6e430f5ad Mon Sep 17 00:00:00 2001 From: lotus Date: Fri, 22 Nov 2024 16:46:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86loss=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E6=9B=B2=E7=BA=BF=E6=98=BE=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/dataset.py | 8 ++++---- src/train.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/src/dataset.py b/src/dataset.py index 7df3a4c..fdc31f6 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -24,7 +24,7 @@ class ChestXrayDataset(Dataset): return image, label -def prepare_data(data_dir, batch_size=32): +def prepare_data(data_dir, batch_size=32): # 调低 batch_size # 获取所有图片文件路径 normal_dir = os.path.join(data_dir, 'normal') pneumonia_dir = os.path.join(data_dir, 'pneumonia') @@ -67,7 +67,7 @@ def prepare_data(data_dir, batch_size=32): # 创建数据加载器 train_loader = DataLoader( train_dataset, - batch_size=batch_size, + batch_size=batch_size, # 使用调低后的 batch_size shuffle=True, num_workers=4, pin_memory=True @@ -75,10 +75,10 @@ def prepare_data(data_dir, batch_size=32): test_loader = DataLoader( test_dataset, - batch_size=batch_size, + batch_size=batch_size, # 使用调低后的 batch_size shuffle=False, num_workers=4, pin_memory=True ) - return train_loader, test_loader + return train_loader, test_loader \ No newline at end of file diff --git a/src/train.py b/src/train.py index a7e9797..7fc27ed 100644 --- a/src/train.py +++ b/src/train.py @@ -5,6 +5,7 @@ from model import MultiResRibNet from dataset import prepare_data import os from tqdm import tqdm +import matplotlib.pyplot as plt def train(model, train_loader, test_loader, num_epochs=50): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -15,6 +16,8 @@ def train(model, train_loader, test_loader, num_epochs=50): scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3) best_acc = 0.0 + train_losses = [] + test_losses = [] for epoch in range(num_epochs): # 训练阶段 @@ -45,6 +48,7 @@ def train(model, train_loader, test_loader, num_epochs=50): train_loss = running_loss / len(train_loader) train_acc = 100. * correct / total + train_losses.append(train_loss) # 测试阶段 model.eval() @@ -65,6 +69,7 @@ def train(model, train_loader, test_loader, num_epochs=50): test_loss = test_loss / len(test_loader) test_acc = 100. * correct / total + test_losses.append(test_loss) scheduler.step(test_loss) @@ -76,6 +81,16 @@ def train(model, train_loader, test_loader, num_epochs=50): if test_acc > best_acc: best_acc = test_acc torch.save(model.state_dict(), 'best_model.pth') + + # 绘制损失图 + plt.figure(figsize=(10, 5)) + plt.plot(train_losses, label='Train Loss') + plt.plot(test_losses, label='Test Loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title('Training and Test Loss') + plt.legend() + plt.show() if __name__ == '__main__': # 数据路径