My model is wrapped in `DataParallel`, and I save its checkpoint as follows:
```python
import torch
import torch.nn as nn

model_single = MyModel()
model = nn.DataParallel(model_single)
model = model.to(device)
torch.save(model.state_dict(), checkpoint_path)  # state_dict of the DataParallel wrapper
```
When I try to load the checkpoint with the code below, I get this error:
```python
model = MyModel()
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)  # ERROR HERE
```
```
RuntimeError: Error(s) in loading state_dict for Glow:
    Missing key(s) in state_dict: "blocks.0.flows.0.actnorm.loc", ..., "blocks.3.prior.conv.weight", "blocks.3.prior.conv.bias".
    Unexpected key(s) in state_dict: "module.blocks.0.flows.0.actnorm.loc", ...
```
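The missing and unexpected keys differ only by a leading `module.`. That prefix comes from `nn.DataParallel`, which stores the wrapped network as a submodule named `module`. A small illustration, using `nn.Linear` as a hypothetical stand-in for `MyModel`:

```python
import torch.nn as nn

# Hypothetical stand-in for MyModel: a single linear layer.
toy = nn.Linear(4, 2)
wrapped = nn.DataParallel(toy)  # also constructs fine on a CPU-only machine

print(list(toy.state_dict().keys()))      # ['weight', 'bias']
print(list(wrapped.state_dict().keys()))  # ['module.weight', 'module.bias']
```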
My model's initialization parameters are the same as during training. The solutions I found on the PyTorch discussion forums mostly suggest saving the model with `model.module.state_dict()` instead. Is there another way to load this checkpoint without retraining the model and saving it differently?
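A minimal sketch of one way to reuse the existing checkpoint, assuming (as the error message suggests) that the only difference between the saved keys and the model's keys is the `module.` prefix: strip the prefix before calling `load_state_dict`. `strip_module_prefix` and `TinyModel` are hypothetical names used only for illustration; the checkpoint here is simulated so the snippet runs on its own.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key.

    Keys that do not start with the prefix are kept unchanged.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Hypothetical stand-in for MyModel.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


# Simulate a checkpoint saved from a DataParallel-wrapped model.
checkpoint = nn.DataParallel(TinyModel()).state_dict()

# Load into a plain (unwrapped) model; strict loading raises if keys still mismatch.
model = TinyModel()
result = model.load_state_dict(strip_module_prefix(checkpoint))
print(result)
```

With a real file, `checkpoint` would instead come from `torch.load(checkpoint_path, map_location="cpu")`. Another option is to wrap `MyModel()` in `nn.DataParallel` before calling `load_state_dict`, so the key names already match, and then use `model.module` as the plain network.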
Please let me know if you require my complete code for better analysis. Thanks in advance.