While trying to load a checkpoint into a resnet model I get this error !
What is wrong here?
this is the snippet that causes this error :
def _init_model(self):
if self.device == 'cpu':
checkpoint = torch.load(self.model_checkpoint_path, map_location=torch.device('cpu'))
else:
checkpoint = torch.load(self.model_checkpoint_path)
if self.model_name == 'r18':
self.model = resnet18(pretrained=False, use_se=False)
elif self.model_name == 'r50':
self.model = resnet50(pretrained=False, use_se=False)
elif self.model_name == 'r101':
self.model = resnet101(pretrained=False, use_se=False)
else:
raise Exception(f"Model name: '{self.model_name}' is not recognized.")
# load the model weights
self.model.load_state_dict(checkpoint['model'])
self.model = self.model.to(self.device)
self.model.eval()
....
and this is the error message I get :
Mar 17 20:47:27 ubuntu python3[20862]: File "/home/user1/anaconda3/lib/python3.7/site-packages/FV/F_V.py", line 58, in __init__
Mar 17 20:47:27 ubuntu python3[20862]: self._init_model()
Mar 17 20:47:27 ubuntu python3[20862]: File "/home/user1/anaconda3/lib/python3.7/site-packages/FV/F_V.py", line 84, in _init_model
Mar 17 20:47:27 ubuntu python3[20862]: self.model.load_state_dict(checkpoint['model'])
Mar 17 20:47:27 ubuntu python3[20862]: File "/home/user1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 803, in load_state_dict
Mar 17 20:47:27 ubuntu python3[20862]: state_dict = state_dict.copy()
Mar 17 20:47:27 ubuntu python3[20862]: File "/home/user1/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 576, in __getattr__
Mar 17 20:47:27 ubuntu python3[20862]: type(self).__name__, name))
Mar 17 20:47:27 ubuntu python3[20862]: AttributeError: 'DataParallel' object has no attribute 'copy'
I want to know how I can get around this error other than doing sth like this which works:
self.model = checkpoint['model'].module
Thanks a lot in adance