添加了loss函数曲线显示

This commit is contained in:
lotus 2024-11-22 16:46:29 +08:00
parent b8468e955b
commit 40bdf03e5a
2 changed files with 19 additions and 4 deletions

View File

@ -24,7 +24,7 @@ class ChestXrayDataset(Dataset):
return image, label
def prepare_data(data_dir, batch_size=32):
def prepare_data(data_dir, batch_size=32): # 调低 batch_size
# 获取所有图片文件路径
normal_dir = os.path.join(data_dir, 'normal')
pneumonia_dir = os.path.join(data_dir, 'pneumonia')
@ -67,7 +67,7 @@ def prepare_data(data_dir, batch_size=32):
# 创建数据加载器
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
batch_size=batch_size, # 使用调低后的 batch_size
shuffle=True,
num_workers=4,
pin_memory=True
@ -75,10 +75,10 @@ def prepare_data(data_dir, batch_size=32):
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
batch_size=batch_size, # 使用调低后的 batch_size
shuffle=False,
num_workers=4,
pin_memory=True
)
return train_loader, test_loader
return train_loader, test_loader

View File

@ -5,6 +5,7 @@ from model import MultiResRibNet
from dataset import prepare_data
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
def train(model, train_loader, test_loader, num_epochs=50):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -15,6 +16,8 @@ def train(model, train_loader, test_loader, num_epochs=50):
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
best_acc = 0.0
train_losses = []
test_losses = []
for epoch in range(num_epochs):
# 训练阶段
@ -45,6 +48,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
train_loss = running_loss / len(train_loader)
train_acc = 100. * correct / total
train_losses.append(train_loss)
# 测试阶段
model.eval()
@ -65,6 +69,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
test_loss = test_loss / len(test_loader)
test_acc = 100. * correct / total
test_losses.append(test_loss)
scheduler.step(test_loss)
@ -76,6 +81,16 @@ def train(model, train_loader, test_loader, num_epochs=50):
if test_acc > best_acc:
best_acc = test_acc
torch.save(model.state_dict(), 'best_model.pth')
# 绘制损失图
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Test Loss')
plt.legend()
plt.show()
if __name__ == '__main__':
# 数据路径