添加了loss函数曲线显示
This commit is contained in:
parent
b8468e955b
commit
40bdf03e5a
@ -24,7 +24,7 @@ class ChestXrayDataset(Dataset):
|
|||||||
|
|
||||||
return image, label
|
return image, label
|
||||||
|
|
||||||
def prepare_data(data_dir, batch_size=32):
|
def prepare_data(data_dir, batch_size=32): # 调低 batch_size
|
||||||
# 获取所有图片文件路径
|
# 获取所有图片文件路径
|
||||||
normal_dir = os.path.join(data_dir, 'normal')
|
normal_dir = os.path.join(data_dir, 'normal')
|
||||||
pneumonia_dir = os.path.join(data_dir, 'pneumonia')
|
pneumonia_dir = os.path.join(data_dir, 'pneumonia')
|
||||||
@ -67,7 +67,7 @@ def prepare_data(data_dir, batch_size=32):
|
|||||||
# 创建数据加载器
|
# 创建数据加载器
|
||||||
train_loader = DataLoader(
|
train_loader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size, # 使用调低后的 batch_size
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
pin_memory=True
|
pin_memory=True
|
||||||
@ -75,10 +75,10 @@ def prepare_data(data_dir, batch_size=32):
|
|||||||
|
|
||||||
test_loader = DataLoader(
|
test_loader = DataLoader(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size, # 使用调低后的 batch_size
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=4,
|
num_workers=4,
|
||||||
pin_memory=True
|
pin_memory=True
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_loader, test_loader
|
return train_loader, test_loader
|
||||||
15
src/train.py
15
src/train.py
@ -5,6 +5,7 @@ from model import MultiResRibNet
|
|||||||
from dataset import prepare_data
|
from dataset import prepare_data
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
def train(model, train_loader, test_loader, num_epochs=50):
|
def train(model, train_loader, test_loader, num_epochs=50):
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@ -15,6 +16,8 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
|||||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
|
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
|
||||||
|
|
||||||
best_acc = 0.0
|
best_acc = 0.0
|
||||||
|
train_losses = []
|
||||||
|
test_losses = []
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(num_epochs):
|
||||||
# 训练阶段
|
# 训练阶段
|
||||||
@ -45,6 +48,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
|||||||
|
|
||||||
train_loss = running_loss / len(train_loader)
|
train_loss = running_loss / len(train_loader)
|
||||||
train_acc = 100. * correct / total
|
train_acc = 100. * correct / total
|
||||||
|
train_losses.append(train_loss)
|
||||||
|
|
||||||
# 测试阶段
|
# 测试阶段
|
||||||
model.eval()
|
model.eval()
|
||||||
@ -65,6 +69,7 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
|||||||
|
|
||||||
test_loss = test_loss / len(test_loader)
|
test_loss = test_loss / len(test_loader)
|
||||||
test_acc = 100. * correct / total
|
test_acc = 100. * correct / total
|
||||||
|
test_losses.append(test_loss)
|
||||||
|
|
||||||
scheduler.step(test_loss)
|
scheduler.step(test_loss)
|
||||||
|
|
||||||
@ -76,6 +81,16 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
|||||||
if test_acc > best_acc:
|
if test_acc > best_acc:
|
||||||
best_acc = test_acc
|
best_acc = test_acc
|
||||||
torch.save(model.state_dict(), 'best_model.pth')
|
torch.save(model.state_dict(), 'best_model.pth')
|
||||||
|
|
||||||
|
# 绘制损失图
|
||||||
|
plt.figure(figsize=(10, 5))
|
||||||
|
plt.plot(train_losses, label='Train Loss')
|
||||||
|
plt.plot(test_losses, label='Test Loss')
|
||||||
|
plt.xlabel('Epoch')
|
||||||
|
plt.ylabel('Loss')
|
||||||
|
plt.title('Training and Test Loss')
|
||||||
|
plt.legend()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 数据路径
|
# 数据路径
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user