Hi,
I want to keep the model's best state_dict, as in the following code snippet:
def finetune(model, train_loader, val_loader, num_epochs, max_num_epochs, ori_acc, acc_drop_threshold):
    _, best_top1_acc, _ = validate(val_loader, model, criterion)
    best_model_state = model.state_dict()
    epoch = 0
    while epoch < max_num_epochs:
        train(epoch, train_loader, model, criterion, optimizer, scheduler)
        _, valid_top1_acc, _ = validate(val_loader, model, criterion)
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            best_model_state = model.state_dict()
        acc_drop = ori_acc - best_top1_acc
        if (acc_drop < acc_drop_threshold) and (epoch > num_epochs-1):
            break
        epoch += 1
    model.load_state_dict(best_model_state)
    return model, epoch
I expected the code above to load the best state dict (the one from the epoch with the highest validation accuracy) before returning from the function. However, when I test it, the function always returns the model in the last state it reached during the while loop.
I don't want to save the model to disk and then reload it, because this is just a small part of the program.
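One possible cause, sketched here under assumptions: state_dict() returns a dict whose values still refer to the model's live parameters, so later in-place updates during training also change the "saved" best state. The snippet below illustrates that aliasing and an in-memory fix with copy.deepcopy. To keep it runnable without PyTorch, TinyModel is a hypothetical stand-in (not a PyTorch class) whose state_dict() hands out a reference to its weight in the same way.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for a model; NOT a PyTorch class."""

    def __init__(self):
        # A mutable list plays the role of a parameter tensor.
        self.weight = [0.0]

    def state_dict(self):
        # New dict each call, but the value is the SAME list object.
        return {"weight": self.weight}

    def load_state_dict(self, state):
        # Copy the values into the existing weight, in place.
        self.weight[:] = state["weight"]

    def train_step(self):
        # In-place update, standing in for optimizer.step().
        self.weight[0] += 1.0


model = TinyModel()

aliased = model.state_dict()                   # shares storage with the model
snapshot = copy.deepcopy(model.state_dict())   # independent in-memory copy

model.train_step()
model.train_step()

# The aliased "best" state has followed the model; the deep copy has not.
aliased_value = aliased["weight"][0]
snapshot_value = snapshot["weight"][0]
print(aliased_value, snapshot_value)

# Restoring from the deep copy brings back the earlier weights.
model.load_state_dict(snapshot)
print(model.weight)
```

If the same thing is happening in finetune, replacing both assignments with best_model_state = copy.deepcopy(model.state_dict()) (plus import copy at the top) should keep the best weights in memory without writing anything to disk.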
Thank you. Any recommendation is appreciated!