# Train for config.epochs epochs. After each epoch, evaluate on the validation
# set and checkpoint the model whenever validation accuracy improves.
#
# Start below any achievable score so the first epoch always writes a
# checkpoint (starting at 0 means a run whose accuracy never exceeds 0 would
# never save anything to config.model_path).
best_accuracy = float("-inf")
for epoch in range(config.epochs):
    # One optimisation pass over the training data.
    train_fx(train_data_loader, model, optimizer, scheduler, device)

    # NOTE(review): assumes eval_fx returns exactly two values
    # (predictions, targets) — TODO confirm against its definition; any other
    # return arity raises ValueError at this unpacking.
    outputs, target = eval_fx(valid_data_loader, model, device)
    accuracy = accuracy_metrics(outputs, target)
    print(f'Accuracy score ----- {accuracy}')

    if accuracy > best_accuracy:
        torch.save(model.state_dict(), config.model_path)
        # Remember the score itself (not the metric function), so the next
        # epoch compares number against number.
        best_accuracy = accuracy
The problem lies here: `outputs, target = eval_fx(valid_data_loader, model, device)`