添加了tqdm显示进度条
This commit is contained in:
parent
67c38df0d5
commit
b8468e955b
13
src/train.py
13
src/train.py
@ -4,6 +4,7 @@ import torch.optim as optim
|
||||
from model import MultiResRibNet
|
||||
from dataset import prepare_data
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
def train(model, train_loader, test_loader, num_epochs=50):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@ -22,7 +23,10 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
||||
correct = 0
|
||||
total = 0
|
||||
|
||||
for inputs, labels in train_loader:
|
||||
# 使用 tqdm 显示进度条
|
||||
train_loader_tqdm = tqdm(train_loader, desc=f'Epoch [{epoch+1}/{num_epochs}]')
|
||||
|
||||
for inputs, labels in train_loader_tqdm:
|
||||
inputs, labels = inputs.to(device), labels.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
@ -35,6 +39,9 @@ def train(model, train_loader, test_loader, num_epochs=50):
|
||||
_, predicted = outputs.max(1)
|
||||
total += labels.size(0)
|
||||
correct += predicted.eq(labels).sum().item()
|
||||
|
||||
# 更新进度条的描述信息
|
||||
train_loader_tqdm.set_postfix(loss=loss.item(), acc=100. * correct / total)
|
||||
|
||||
train_loss = running_loss / len(train_loader)
|
||||
train_acc = 100. * correct / total
|
||||
@ -81,6 +88,4 @@ if __name__ == '__main__':
|
||||
model = MultiResRibNet()
|
||||
|
||||
# 训练模型
|
||||
train(model, train_loader, test_loader)
|
||||
|
||||
|
||||
train(model, train_loader, test_loader)
|
||||
Loading…
Reference in New Issue
Block a user