I want to get the best model state_dict as in the following code snippet:
def finetune(model, train_loader, val_loader, num_epochs, max_num_epochs, ori_acc, acc_drop_threshold):
    _, best_top1_acc, _ = validate(val_loader, model, criterion)
    best_model_state = model.state_dict()
    epoch = 0
    while epoch < max_num_epochs:
        train(epoch, train_loader, model, criterion, optimizer, scheduler)
        _, valid_top1_acc, _ = validate(val_loader, model, criterion)
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            best_model_state = model.state_dict()
        acc_drop = ori_acc - best_top1_acc
        if (acc_drop < acc_drop_threshold) and (epoch > num_epochs-1):
            break
        epoch += 1
    model.load_state_dict(best_model_state)
    return model, epoch
I expected this code to load the best state dict (the one from the epoch with the highest validation accuracy) into the model before returning. However, when I test it, the function always returns the model in the last state it reached during the while loop.
I don’t want to save the model to disk and then reload it, because this is just a small part of the program.
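A possible cause, offered as a guess rather than a confirmed diagnosis: `model.state_dict()` returns tensors that share storage with the live parameters, so `best_model_state` may keep changing as training continues. The minimal sketch below assumes a toy `nn.Linear` as a hypothetical stand-in for the real model and simulates a training step with an in-place weight update. It compares an aliased state dict with a `copy.deepcopy` snapshot that stays in memory.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(2, 1)  # hypothetical stand-in for the real model

# state_dict() hands back tensors that share storage with the parameters.
aliased = model.state_dict()
# deepcopy creates an independent in-memory snapshot (no disk I/O).
snapshot = copy.deepcopy(model.state_dict())

# Simulate a training step: optimizers update parameters in place.
with torch.no_grad():
    model.weight.add_(1.0)

# The aliased dict follows the model; the snapshot keeps the old weights.
aliased_tracks_model = torch.equal(aliased["weight"], model.weight)
snapshot_unchanged = not torch.equal(snapshot["weight"], model.weight)

# Loading the snapshot restores the earlier weights.
model.load_state_dict(snapshot)
restored = torch.equal(model.weight, snapshot["weight"])

print(aliased_tracks_model, snapshot_unchanged, restored)
```

If this is indeed the cause, replacing both `best_model_state = model.state_dict()` assignments in `finetune` with `copy.deepcopy(model.state_dict())` should keep the best weights in memory without saving anything to disk.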
Thank you. Any recommendation is appreciated!