How can I save best model state_dict during training?

I want to get the best model state_dict as in the following code snippet:

def finetune(model, train_loader, val_loader, num_epochs, max_num_epochs, ori_acc, acc_drop_threshold):

    _, best_top1_acc, _ = validate(val_loader, model, criterion)
    best_model_state = model.state_dict()
    epoch = 0
    while epoch < max_num_epochs:
        train(epoch, train_loader, model, criterion, optimizer, scheduler)
        _, valid_top1_acc, _ = validate(val_loader, model, criterion)

        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            best_model_state = model.state_dict()

        acc_drop = ori_acc - best_top1_acc
        if (acc_drop < acc_drop_threshold) and (epoch > num_epochs - 1):
            break

        epoch += 1

    # Restore the best weights before returning.
    model.load_state_dict(best_model_state)
    return model, epoch

I expected that with the above code, the model would always load the best state dict (the one from the epoch where it reached the best validation accuracy) before leaving the function. But when I test it, the function always returns the model with the last state it had at the end of the while loop.
I don't want to save the model to disk and reload it, because this is just a small part of the program.
Thank you. Any recommendation is appreciated!


The .state_dict() method does not copy the parameters; the tensors it returns share storage with the model's live parameters, so they keep changing as training updates the model in place. That is why best_model_state always ends up holding the last weights. If you want an independent snapshot that training will not overwrite, deepcopy it: best_model_state = copy.deepcopy(model.state_dict())
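Here is a minimal, self-contained sketch of the difference, using a toy `torch.nn.Linear` in place of your actual model (the names are illustrative, not from your code):

```python
import copy
import torch

model = torch.nn.Linear(2, 2)

shallow = model.state_dict()                  # tensors share storage with the live parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy, frozen at this point

# Simulate a training update that modifies the parameters in place.
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# The shallow dict changed along with the model; the deepcopy did not.
print(torch.equal(shallow["weight"], model.weight))   # True
print(torch.equal(snapshot["weight"], model.weight))  # False

# Restoring the snapshot brings the model back to the saved weights.
model.load_state_dict(snapshot)
```

So in your `finetune` function, replacing both `best_model_state = model.state_dict()` lines with the deepcopy version should make `model.load_state_dict(best_model_state)` restore the weights you actually want.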
