Error while loading state_dict

When I try to load a model:

import torch

# Load the trained model from file
trained_model = Net()  # Net is my model class
m = torch.load('./modell.pth', map_location=lambda storage, loc: storage)
trained_model.load_state_dict(m)

it fails with:

trained_model.load_state_dict(m)
  File "/home/surya/myenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 751, in load_state_dict
    state_dict = state_dict.copy()
  File "/home/surya/myenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 539, in __getattr__
    type(self).__name__, name))
AttributeError: 'Net' object has no attribute 'copy'

Next, I tried copy.copy(state_dict) and got a new error:

trained_model.load_state_dict(m)
  File "/home/surya/myenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 763, in load_state_dict
    load(self)
  File "/home/surya/myenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 758, in load
    state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
  File "/home/surya/myenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 720, in _load_from_state_dict
    for key in state_dict.keys():
  File "/home/surya/myenv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 539, in __getattr__
    type(self).__name__, name))
AttributeError: 'Net' object has no attribute 'keys'

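My guess is that you saved with torch.save(model), which pickles the entire Net object (not just its weights), instead of torch.save(model.state_dict()), which saves only the state dict, i.e. the parameters and buffers. In that case torch.load returns a Net instance rather than a dict, which is why load_state_dict fails as soon as it calls .copy() or .keys() on it.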
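
A minimal sketch of the difference, assuming PyTorch 1.13 or later (needed for the weights_only argument). nn.Sequential stands in for Net and an in-memory io.BytesIO buffer replaces ./modell.pth; both are illustration choices, not from the thread.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for Net, built from torch's own layers.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# torch.save(model): the whole nn.Module object is pickled.
full_buf = io.BytesIO()
torch.save(model, full_buf)
full_buf.seek(0)
# weights_only=False is required to unpickle a module on PyTorch >= 2.6.
loaded_full = torch.load(full_buf, map_location="cpu", weights_only=False)

# torch.save(model.state_dict()): only a name -> tensor mapping is stored.
sd_buf = io.BytesIO()
torch.save(model.state_dict(), sd_buf)
sd_buf.seek(0)
loaded_sd = torch.load(sd_buf, map_location="cpu", weights_only=True)

print(type(loaded_full).__name__)  # Sequential (a module, not a dict)
print(list(loaded_sd.keys()))      # ['0.weight', '0.bias', '2.weight', '2.bias']

# Passing the module where a dict is expected reproduces the failure.
# Older PyTorch raised AttributeError (as in the traceback above);
# newer releases check the type up front and raise TypeError instead.
try:
    model.load_state_dict(loaded_full)
except (AttributeError, TypeError) as err:
    print(type(err).__name__, err)
```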

If you need to stick with the first approach, load the file with model = torch.load(PATH).
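
A sketch of two ways to use a file written with torch.save(model), assuming PyTorch 1.13 or later. Since PyTorch 2.6, torch.load defaults to weights_only=True, so unpickling a whole module needs weights_only=False, which should only be used on files you trust. build_model() is a hypothetical stand-in for Net() and the buffer replaces PATH.

```python
import io

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture standing in for the poster's Net().
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


# Simulate a file written with torch.save(model).
buf = io.BytesIO()
torch.save(build_model(), buf)
buf.seek(0)

# Option 1: the loaded object already is the model (model = torch.load(PATH)).
# weights_only=False unpickles arbitrary objects: trusted files only.
m = torch.load(buf, map_location="cpu", weights_only=False)
m.eval()

# Option 2: keep a freshly constructed instance and copy the weights into it;
# m is a module, so pass m.state_dict() rather than m itself.
trained_model = build_model()
result = trained_model.load_state_dict(m.state_dict())
print(result.missing_keys, result.unexpected_keys)  # [] []
```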

More info here.

What if we saved the model with torch.save(model), but now need to load only the state dict, because torch.load(modelpath) breaks, as the official documentation warns?

Does a newer version of PyTorch prevent you from loading the model directly? In that case, load the model with the older PyTorch version, save its state_dict, and then load that state_dict with the newer version.
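
The two-step conversion could look roughly like this. It is a sketch, not a tested migration tool: the helper names convert_full_checkpoint and load_converted, the file names, and the architecture in build_model() are all hypothetical. In practice step 1 runs in the old environment and step 2 in the new one; here both run in one process only to demonstrate the round trip.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture; in practice, construct your own Net() here.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def convert_full_checkpoint(src_path: str, dst_path: str) -> None:
    """Step 1 (old PyTorch): unpickle the whole module once and
    re-save only its state dict."""
    # weights_only=False is needed to unpickle a module; drop the argument
    # on PyTorch versions older than 1.13, which do not accept it.
    model = torch.load(src_path, map_location="cpu", weights_only=False)
    torch.save(model.state_dict(), dst_path)


def load_converted(dst_path: str) -> nn.Module:
    """Step 2 (new PyTorch): rebuild the architecture in code and load
    the plain tensor dict."""
    model = build_model()
    state_dict = torch.load(dst_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    return model.eval()  # eval() returns the module itself


with tempfile.TemporaryDirectory() as tmp:
    full_path = os.path.join(tmp, "full_model.pth")  # hypothetical names
    sd_path = os.path.join(tmp, "state_dict.pth")
    original = build_model()
    torch.save(original, full_path)  # legacy style: whole module pickled

    convert_full_checkpoint(full_path, sd_path)
    restored = load_converted(sd_path)
```

Re-saving only the state dict also removes the dependency on the original class's import path, which is what makes whole-model pickles fragile across code and version changes.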

Thanks, I will double-check the PyTorch version, though it is most likely the same. The problem seems to be something else: I got a size mismatch error.
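
A size mismatch usually means the model being constructed does not have the same layer shapes as the one that was saved. A small diagnostic sketch that lists the differing tensors; the function name shape_mismatches and the 10-versus-2 output layer are hypothetical, chosen only to trigger a mismatch.

```python
import torch.nn as nn


def shape_mismatches(model: nn.Module, state_dict: dict) -> list:
    """Return (key, checkpoint_shape, model_shape) for every tensor whose
    shape differs between a checkpoint state dict and the model."""
    model_sd = model.state_dict()
    mismatches = []
    for key, tensor in state_dict.items():
        if key in model_sd and model_sd[key].shape != tensor.shape:
            mismatches.append((key, tuple(tensor.shape), tuple(model_sd[key].shape)))
    return mismatches


# Hypothetical case: the checkpoint has 10 outputs, the current model only 2.
saved = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 10)).state_dict()
current = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

for key, ckpt_shape, model_shape in shape_mismatches(current, saved):
    print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
# 2.weight: checkpoint (10, 8) vs model (2, 8)
# 2.bias: checkpoint (10,) vs model (2,)
```

Note that load_state_dict(..., strict=False) only tolerates missing or unexpected keys; shape mismatches still raise a RuntimeError, so the model has to be constructed with the same arguments used at save time.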

Thank you very much. The problem went away once I changed the way I save the model.