添加了loss函数曲线显示
This commit is contained in:
parent
b8468e955b
commit
40bdf03e5a
@ -24,7 +24,7 @@ class ChestXrayDataset(Dataset):
|
||||
|
||||
return image, label
|
||||
|
||||
def prepare_data(data_dir, batch_size=32):
|
||||
def prepare_data(data_dir, batch_size=32): # 调低 batch_size
|
||||
# 获取所有图片文件路径
|
||||
normal_dir = os.path.join(data_dir, 'normal')
|
||||
pneumonia_dir = os.path.join(data_dir, 'pneumonia')
|
||||
@ -67,7 +67,7 @@ def prepare_data(data_dir, batch_size=32):
|
||||
# 创建数据加载器
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=batch_size,
|
||||
batch_size=batch_size, # 使用调低后的 batch_size
|
||||
shuffle=True,
|
||||
num_workers=4,
|
||||
pin_memory=True
|
||||
@ -75,10 +75,10 @@ def prepare_data(data_dir, batch_size=32):
|
||||
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=batch_size,
|
||||
batch_size=batch_size, # 使用调低后的 batch_size
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
return train_loader, test_loader
|
||||
return train_loader, test_loader
|
||||
15
src/train.py
15
src/train.py
@ -5,6 +5,7 @@ from model import MultiResRibNet
|
||||
from dataset import prepare_data
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def train(model, train_loader, test_loader, num_epochs=50):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@ -15,6 +16,8 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
|
||||
|
||||
best_acc = 0.0
|
||||
train_losses = []
|
||||
test_losses = []
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# 训练阶段
|
||||
@ -45,6 +48,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
||||
|
||||
train_loss = running_loss / len(train_loader)
|
||||
train_acc = 100. * correct / total
|
||||
train_losses.append(train_loss)
|
||||
|
||||
# 测试阶段
|
||||
model.eval()
|
||||
@ -65,6 +69,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
||||
|
||||
test_loss = test_loss / len(test_loader)
|
||||
test_acc = 100. * correct / total
|
||||
test_losses.append(test_loss)
|
||||
|
||||
scheduler.step(test_loss)
|
||||
|
||||
@ -76,6 +81,16 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
||||
if test_acc > best_acc:
|
||||
best_acc = test_acc
|
||||
torch.save(model.state_dict(), 'best_model.pth')
|
||||
|
||||
# 绘制损失图
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(train_losses, label='Train Loss')
|
||||
plt.plot(test_losses, label='Test Loss')
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Loss')
|
||||
plt.title('Training and Test Loss')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 数据路径
|
||||
|
||||
Loading…
Reference in New Issue
Block a user