Load a saved model that was saved with torch.save([model, criterion, optimizer], f)

Hi,

I trained a model using existing code. The model is saved as follows during training:

def model_save(fn):
    with open(fn, 'wb') as f:
        torch.save([model, criterion, optimizer], f)

During inference the original code loads the model like this:

    with open(path, 'rb') as f:
        model = torch.load(f, map_location=lambda storage, loc: storage)

    model.eval()

However, since the model is saved in a list, model.eval() throws an error.
I used the first element of the list, which is supposed to be the model itself, but I am not sure whether the model's state_dict() is loaded at all.
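
Roughly, the loading now looks like this (a rough sketch; path is the checkpoint file):

    with open(path, 'rb') as f:
        # torch.load returns the saved [model, criterion, optimizer] list,
        # so I take the first element as the model
        model = torch.load(f, map_location=lambda storage, loc: storage)[0]

    model.eval()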

You could try to manually check if all parameters were loaded appropriately.
The recommended way is to store the state_dicts instead of the objects directly, as your approach might break in various ways (described here).
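
As a rough sketch (MyModel stands in for your actual model class), the state_dict approach would look like this:

    # save only the parameters
    torch.save(model.state_dict(), fn)

    # load: recreate the model first, then restore the parameters
    model = MyModel(...)  # same constructor arguments as during training
    model.load_state_dict(torch.load(fn, map_location='cpu'))
    model.eval()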

I see.
Besides the breaking issue, is the loading code fine?
I mean, we don't need the second and third elements of the torch.load output, since they are the criterion and the optimizer, right? All parameters should be in the model?

All parameters should be loaded properly if you don't see any error message.
The optimizer's state_dict might be necessary to continue your training if your optimizer uses internal states, e.g. Adam uses running estimates of the gradients.
If you don't restore the optimizer, you might see a loss spike when continuing the training.
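
For example, a sketch of a checkpoint that keeps both state_dicts (the dict keys are just examples):

    # save model and optimizer state together
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, fn)

    # restore both before continuing the training
    checkpoint = torch.load(fn, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])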


Thanks for the explanation.
Actually, I trained the model again from scratch, saved the state_dict with torch.save(model.state_dict(), fn), and loaded it accordingly. However, I am getting this error when I try to test the model:

Traceback (most recent call last):
  File "pytorch_src/test_evaluate.py", line 144, in <module>
    model_load(args.save)
  File "pytorch_src/test_evaluate.py", line 109, in model_load
    model.load_state_dict(torch.load(fn))
  File "/home/anaconda3/envs/py36/lib/python3.7/site-packages/torch/nn/modules/module.py", line 839, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for RNNModel:
        Unexpected key(s) in state_dict: "rnns.0.module.weight_hh_l0", "rnns.1.module.weight_hh_l0", "rnns.2.module.weight_hh_l0".

I guess there must be a mismatch between the model and the state_dict keys, but I am not sure how that happened.
It looks like a 'module.' part gets inserted into the original keys (e.g. 'rnns.0.module.'). Do you have any idea?

Was the original model saved with DataParallel activated? If so, that explains the module prefix occurring in your state_dict. See my previous answer here for details on how to fix this (in both saving and loading situations).
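
As a rough sketch (assuming the extra module. part indeed comes from such a wrapper), you could strip it from the keys before loading:

    # remove the 'module.' part the wrapper added to the parameter names
    state_dict = torch.load(fn, map_location='cpu')
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)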
