Unable to load model using load_state_dict

I trained SqueezeNet on a GPU using nn.DataParallel, then saved the model.
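Here is a minimal sketch of what the saving step probably looked like. It uses a small hypothetical TinyNet in place of SqueezeNet; its layer names mimic SqueezeNet's features/classifier, so the keys resemble the error below. nn.DataParallel keeps the wrapped network in its module attribute, so every key in the wrapper's state_dict gains a "module." prefix. The sketch also runs on a CPU-only machine, where DataParallel simply holds the module.

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for SqueezeNet, reusing its `features`/`classifier` naming."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))
        self.classifier = nn.Sequential(nn.Dropout(), nn.Conv2d(4, 2, kernel_size=1))


model = TinyNet()
dp_model = nn.DataParallel(model)  # the original network becomes dp_model.module

print(list(model.state_dict().keys()))
# ['features.0.weight', 'features.0.bias', 'classifier.1.weight', 'classifier.1.bias']
print(list(dp_model.state_dict().keys()))
# ['module.features.0.weight', 'module.features.0.bias', 'module.classifier.1.weight', 'module.classifier.1.bias']

# Saving the wrapper's state_dict writes the prefixed keys into the checkpoint.
buffer = io.BytesIO()  # in-memory stand-in for the .pt file
torch.save(dp_model.state_dict(), buffer)
```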

  1. torch.load works:
    torch.load('./model_ftsqueezenet_morse150.pt',map_location={'cuda:0': 'cpu'})

  2. but model.load_state_dict gives an error:
    model.load_state_dict(torch.load('./model_ftsqueezenet_morse150.pt',map_location={'cuda:0': 'cpu'}))
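The first call succeeds because torch.load only unpickles the file into a dictionary of tensors. map_location={'cuda:0': 'cpu'} moves tensors that were saved on cuda:0 onto the CPU. Nothing compares the key names against a model until load_state_dict runs. The sketch below assumes a hypothetical one-layer model and uses an in-memory buffer in place of the .pt file:

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the DataParallel-wrapped SqueezeNet.
dp_model = nn.DataParallel(nn.Sequential(nn.Conv2d(3, 4, kernel_size=3)))
buffer = io.BytesIO()  # in-memory stand-in for model_ftsqueezenet_morse150.pt
torch.save(dp_model.state_dict(), buffer)
buffer.seek(0)

# torch.load only deserializes; map_location remaps tensors saved on cuda:0
# to the CPU. No parameter names are checked at this stage.
state_dict = torch.load(buffer, map_location={"cuda:0": "cpu"})
print(list(state_dict.keys()))  # ['module.0.weight', 'module.0.bias']
```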


RuntimeError Traceback (most recent call last)
in ()
----> 1 model.load_state_dict(torch.load('./visualize_results/pretraining/models/model_ftsqueezenet_morse150.pt',map_location={'cuda:0': 'cpu'}))

~/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
757 if len(error_msgs) > 0:
758 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 759 self.__class__.__name__, "\n\t".join(error_msgs)))
760
761 def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for SqueezeNet:
Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.3.squeeze.weight", "features.3.squeeze.bias", "features.3.expand1x1.weight", "features.3.expand1x1.bias", "features.3.expand3x3.weight", "features.3.expand3x3.bias", "features.4.squeeze.weight", "features.4.squeeze.bias", "features.4.expand1x1.weight", "features.4.expand1x1.bias", "features.4.expand3x3.weight", "features.4.expand3x3.bias", "features.5.squeeze.weight", "features.5.squeeze.bias", "features.5.expand1x1.weight", "features.5.expand1x1.bias", "features.5.expand3x3.weight", "features.5.expand3x3.bias", "features.7.squeeze.weight", "features.7.squeeze.bias", "features.7.expand1x1.weight", "features.7.expand1x1.bias", "features.7.expand3x3.weight", "features.7.expand3x3.bias", "features.8.squeeze.weight", "features.8.squeeze.bias", "features.8.expand1x1.weight", "features.8.expand1x1.bias", "features.8.expand3x3.weight", "features.8.expand3x3.bias", "features.9.squeeze.weight", "features.9.squeeze.bias", "features.9.expand1x1.weight", "features.9.expand1x1.bias", "features.9.expand3x3.weight", "features.9.expand3x3.bias", "features.10.squeeze.weight", "features.10.squeeze.bias", "features.10.expand1x1.weight", "features.10.expand1x1.bias", "features.10.expand3x3.weight", "features.10.expand3x3.bias", "features.12.squeeze.weight", "features.12.squeeze.bias", "features.12.expand1x1.weight", "features.12.expand1x1.bias", "features.12.expand3x3.weight", "features.12.expand3x3.bias", "classifier.1.weight", "classifier.1.bias".
Unexpected key(s) in state_dict: "module.features.0.weight", "module.features.0.bias", "module.features.3.squeeze.weight", "module.features.3.squeeze.bias", "module.features.3.expand1x1.weight", "module.features.3.expand1x1.bias", "module.features.3.expand3x3.weight", "module.features.3.expand3x3.bias", "module.features.4.squeeze.weight", "module.features.4.squeeze.bias", "module.features.4.expand1x1.weight", "module.features.4.expand1x1.bias", "module.features.4.expand3x3.weight", "module.features.4.expand3x3.bias", "module.features.5.squeeze.weight", "module.features.5.squeeze.bias", "module.features.5.expand1x1.weight", "module.features.5.expand1x1.bias", "module.features.5.expand3x3.weight", "module.features.5.expand3x3.bias", "module.features.7.squeeze.weight", "module.features.7.squeeze.bias", "module.features.7.expand1x1.weight", "module.features.7.expand1x1.bias", "module.features.7.expand3x3.weight", "module.features.7.expand3x3.bias", "module.features.8.squeeze.weight", "module.features.8.squeeze.bias", "module.features.8.expand1x1.weight", "module.features.8.expand1x1.bias", "module.features.8.expand3x3.weight", "module.features.8.expand3x3.bias", "module.features.9.squeeze.weight", "module.features.9.squeeze.bias", "module.features.9.expand1x1.weight", "module.features.9.expand1x1.bias", "module.features.9.expand3x3.weight", "module.features.9.expand3x3.bias", "module.features.10.squeeze.weight", "module.features.10.squeeze.bias", "module.features.10.expand1x1.weight", "module.features.10.expand1x1.bias", "module.features.10.expand3x3.weight", "module.features.10.expand3x3.bias", "module.features.12.squeeze.weight", "module.features.12.squeeze.bias", "module.features.12.expand1x1.weight", "module.features.12.expand1x1.bias", "module.features.12.expand3x3.weight", "module.features.12.expand3x3.bias", "module.classifier.1.weight", "module.classifier.1.bias".
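You can list the mismatch directly instead of reading through the long error. This sketch calls load_state_dict(..., strict=False), which in recent PyTorch versions returns the missing and unexpected keys instead of raising. The one-layer model is a hypothetical stand-in. With a pure prefix mismatch, strict=False loads nothing at all, so treat it as a diagnostic, not a fix.

```python
import torch.nn as nn

# Hypothetical stand-ins: a plain model, and a checkpoint taken from a DataParallel wrapper.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))
checkpoint = nn.DataParallel(nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))).state_dict()

# strict=False reports mismatched names instead of raising. Here no key matches,
# so no weights are actually copied into `model`.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['0.weight', '0.bias']
print(result.unexpected_keys)  # ['module.0.weight', 'module.0.bias']

# Every unexpected key is a missing key with "module." prepended.
print(sorted("module." + k for k in result.missing_keys) == sorted(result.unexpected_keys))  # True
```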

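You are saving the state_dict of the DataParallel wrapper, whose keys all carry a "module." prefix (see the "Unexpected key(s)" above), but loading it into a plain, non-DataParallel model that expects the unprefixed names. If you wrap the model in nn.DataParallel before calling load_state_dict, it will work. Alternatively, when saving, save the state_dict of the wrapped module instead: dp_module.module.state_dict().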
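Both fixes from the answer are sketched below. A hypothetical make_model() stands in for the SqueezeNet constructor, and in-memory buffers stand in for the .pt file. The last part, which strips the prefix from a checkpoint that is already saved, is an extra option and not part of the answer.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for the SqueezeNet constructor; any nn.Module behaves the same.
    return nn.Sequential(nn.Conv2d(3, 4, kernel_size=3))


# The checkpoint from the question: a state_dict saved from the DataParallel wrapper.
trained = nn.DataParallel(make_model())
buffer = io.BytesIO()  # in-memory stand-in for the .pt file
torch.save(trained.state_dict(), buffer)

# Fix 1: wrap the fresh model in DataParallel *before* loading, so the "module." keys match.
buffer.seek(0)
model = nn.DataParallel(make_model())
model.load_state_dict(torch.load(buffer, map_location="cpu"))
plain_model = model.module  # unwrap afterwards if a regular nn.Module is needed

# Fix 2: when saving, save the wrapped module's state_dict, which has no "module." prefix.
buffer2 = io.BytesIO()
torch.save(trained.module.state_dict(), buffer2)
buffer2.seek(0)
model2 = make_model()
model2.load_state_dict(torch.load(buffer2, map_location="cpu"))

# Extra option (not from the answer): strip the prefix from an existing checkpoint.
buffer.seek(0)
state_dict = torch.load(buffer, map_location="cpu")
prefix = "module."
stripped = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
model3 = make_model()
model3.load_state_dict(stripped)
```

With Fix 2 the checkpoint no longer depends on DataParallel. It loads into a plain model directly, and into a wrapper through dp_model.module.load_state_dict(...).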