添加了tqdm显示进度条

This commit is contained in:
lotus 2024-11-22 00:09:33 +08:00
parent 67c38df0d5
commit b8468e955b

View File

@ -4,6 +4,7 @@ import torch.optim as optim
from model import MultiResRibNet
from dataset import prepare_data
import os
from tqdm import tqdm
def train(model, train_loader, test_loader, num_epochs=50):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@ -22,7 +23,10 @@ def train(model, train_loader, test_loader, num_epochs=50):
correct = 0
total = 0
for inputs, labels in train_loader:
# 使用 tqdm 显示进度条
train_loader_tqdm = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]')
for inputs, labels in train_loader_tqdm:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
@ -35,6 +39,9 @@ def train(model, train_loader, test_loader, num_epochs=50):
_, predicted = outputs.max(1)
total += labels.size(0)
correct += predicted.eq(labels).sum().item()
# 更新进度条的描述信息
train_loader_tqdm.set_postfix(loss=loss.item(), acc=100. * correct / total)
train_loss = running_loss / len(train_loader)
train_acc = 100. * correct / total
@ -81,6 +88,4 @@ if __name__ == '__main__':
model = MultiResRibNet()
# 训练模型
train(model, train_loader, test_loader)
train(model, train_loader, test_loader)