import torch
from tqdm.auto import tqdm
# Fix the random seed so runs are reproducible.
torch.manual_seed(42)

# Number of passes over the training data.
epochs = 5

for epoch in tqdm(range(epochs)):
    print(f'第{epoch}次…')

    # ---- Training ----
    # Put the model in training mode once per epoch (affects layers such as
    # dropout / batch-norm, if the model has any); no need to repeat per batch.
    mode_l.train()
    # Sum of per-batch training losses for this epoch, as a plain Python float.
    train_loss = 0.0
    for batch, (X, y) in enumerate(train_data_lodael):
        # Forward pass
        y_pred = mode_l(X)
        # Compute the loss
        loss = loss_fn(y_pred, y)
        # .item() extracts a Python number; accumulating the tensor itself
        # would keep every batch's autograd graph alive for the whole epoch.
        train_loss += loss.item()
        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 400 == 0:
            print('------------------')
    # Mean loss per training batch.
    train_loss /= len(train_data_lodael)

    # ---- Evaluation ----
    test_loss, test_acc = 0.0, 0.0
    mode_l.eval()
    with torch.inference_mode():
        for X_test, y_test in test_data_lodael:
            # Forward pass on the test batch
            test_y_pred = mode_l(X_test)
            # Score the test predictions against the test labels.
            test_loss += loss_fn(test_y_pred, y_test).item()
            test_acc += accuracy_fn(y_true=y_test, y_pred=test_y_pred)
    # Average loss / accuracy over the number of test batches.
    test_loss /= len(test_data_lodael)
    test_acc /= len(test_data_lodael)
    print(f'\n训练损失:{train_loss:.4f}| 测试损失:{test_loss:.2f}__测试准确度:{test_acc}')