diff --git a/src/train.py b/src/train.py index c462d6a..a7e9797 100644 --- a/src/train.py +++ b/src/train.py @@ -4,6 +4,7 @@ import torch.optim as optim from model import MultiResRibNet from dataset import prepare_data import os +from tqdm import tqdm def train(model, train_loader, test_loader, num_epochs=50): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -22,7 +23,10 @@ def train(model, train_loader, test_loader, num_epochs=50): correct = 0 total = 0 - for inputs, labels in train_loader: + # 使用 tqdm 显示进度条 + train_loader_tqdm = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]') + + for inputs, labels in train_loader_tqdm: inputs, labels = inputs.to(device), labels.to(device) optimizer.zero_grad() @@ -35,6 +39,9 @@ def train(model, train_loader, test_loader, num_epochs=50): _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() + + # 更新进度条的描述信息 + train_loader_tqdm.set_postfix(loss=loss.item(), acc=100. * correct / total) train_loss = running_loss / len(train_loader) train_acc = 100. * correct / total @@ -81,6 +88,4 @@ if __name__ == '__main__': model = MultiResRibNet() # 训练模型 - train(model, train_loader, test_loader) - - + train(model, train_loader, test_loader) \ No newline at end of file