From b8468e955b5ec4bbead9b385534d4f3d6c6a7a7c Mon Sep 17 00:00:00 2001 From: lotus Date: Fri, 22 Nov 2024 00:09:33 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86tqdm=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E8=BF=9B=E5=BA=A6=E6=9D=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/train.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/train.py b/src/train.py index c462d6a..a7e9797 100644 --- a/src/train.py +++ b/src/train.py @@ -4,6 +4,7 @@ import torch.optim as optim from model import MultiResRibNet from dataset import prepare_data import os +from tqdm import tqdm def train(model, train_loader, test_loader, num_epochs=50): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -22,7 +23,10 @@ def train(model, train_loader, test_loader, num_epochs=50): correct = 0 total = 0 - for inputs, labels in train_loader: + # 使用 tqdm 显示进度条 + train_loader_tqdm = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]') + + for inputs, labels in train_loader_tqdm: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() @@ -35,6 +39,9 @@ def train(model, train_loader, test_loader, num_epochs=50): _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() + + # 更新进度条的描述信息 + train_loader_tqdm.set_postfix(loss=loss.item(), acc=100. * correct / total) train_loss = running_loss / len(train_loader) train_acc = 100. * correct / total @@ -81,6 +88,4 @@ if __name__ == '__main__': model = MultiResRibNet() # 训练模型 - train(model, train_loader, test_loader) - - + train(model, train_loader, test_loader) \ No newline at end of file