Loading a model that was saved using DataParallel

My model is wrapped in DataParallel, and its checkpoint is saved as follows:

import torch
import torch.nn as nn

model_single = MyModel()
model = nn.DataParallel(model_single)
model = model.to(device)
torch.save(model.state_dict(), checkpoint_path)  # saves the DataParallel wrapper's state_dict

When I try to load the checkpoint with the code below, I get this error:

model = MyModel()
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint) #ERROR HERE
RuntimeError: Error(s) in loading state_dict for Glow:
	Missing key(s) in state_dict: "blocks.0.flows.0.actnorm.loc", ..., "blocks.3.prior.conv.weight", "blocks.3.prior.conv.bias".
	Unexpected key(s) in state_dict: "module.blocks.0.flows.0.actnorm.loc", ...
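If you want to inspect this mismatch without raising the RuntimeError, `load_state_dict` accepts `strict=False` and returns the missing and unexpected keys. With `strict=False` the mismatched entries are simply not loaded, so this is only useful for diagnosis. A minimal sketch, assuming a hypothetical `TinyModel` in place of `MyModel` and a DataParallel `state_dict` in place of the loaded checkpoint:

```python
import torch.nn as nn

# Hypothetical stand-in for MyModel: one linear layer named "fc".
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Stands in for torch.load(checkpoint_path, map_location="cpu").
checkpoint = nn.DataParallel(TinyModel()).state_dict()

model = TinyModel()
# strict=False reports the mismatch instead of raising RuntimeError.
result = model.load_state_dict(checkpoint, strict=False)
print(result.missing_keys)     # ['fc.weight', 'fc.bias']
print(result.unexpected_keys)  # ['module.fc.weight', 'module.fc.bias']
```

The two lists mirror the "Missing key(s)" and "Unexpected key(s)" lines of the error: the same names, differing only by the `module.` prefix.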

The model is initialized with the same parameters as during training. The solutions I found on the PyTorch forums mostly suggest saving the model with model.module.state_dict() instead. Is there any way to load this existing checkpoint without retraining the model and saving it differently?

Please let me know if you need my complete code for a better analysis. Thanks in advance.

It seems your saved state_dict has an extra ‘module.’ prefix on its keys.
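The prefix comes from the wrapper itself: `nn.DataParallel` stores the wrapped network in its `module` attribute, so every key in its `state_dict` is nested under `module.`. A minimal sketch, assuming a hypothetical `TinyModel` in place of `MyModel`:

```python
import torch.nn as nn

# Hypothetical stand-in for MyModel: one linear layer named "fc".
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

plain = TinyModel()
wrapped = nn.DataParallel(TinyModel())

# The wrapper registers the network as the submodule "module",
# so its parameter names gain a "module." prefix.
print(list(plain.state_dict().keys()))    # ['fc.weight', 'fc.bias']
print(list(wrapped.state_dict().keys()))  # ['module.fc.weight', 'module.fc.bias']
```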

In the example above, you save a DataParallel model and then load it into a plain MyModel. A similar issue was discussed here.
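If you would rather keep the existing checkpoint and load it into a plain `MyModel` (no retraining needed), one option is to strip the prefix from the keys before calling `load_state_dict`. A minimal sketch; `strip_prefix` and `TinyModel` are hypothetical names used only for illustration:

```python
import torch
import torch.nn as nn

def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Hypothetical stand-in for MyModel: one linear layer named "fc".
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# Stands in for torch.load(checkpoint_path, map_location="cpu").
checkpoint = nn.DataParallel(TinyModel()).state_dict()

model = TinyModel()
model.load_state_dict(strip_prefix(checkpoint))  # keys now match, strict loading succeeds
```

Recent PyTorch versions also include `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, "module.")`, which strips the prefix in place; check that your installed version has it before relying on it.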

Thanks for your response. I was able to solve the issue by wrapping the model in DataParallel again before loading the checkpoint.

model_single = MyModel()
# solution: wrap in DataParallel again so the keys carry the same "module." prefix
model = nn.DataParallel(model_single)
model = model.to(device)
# end
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)  # now loads without error
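After the wrapped model has loaded, `model.module` is the underlying `MyModel`, so you can use it without DataParallel and, if you like, re-save the weights once without the prefix so future loads work with a plain model. A minimal sketch, assuming a hypothetical `TinyModel` and an in-memory buffer in place of a real checkpoint path:

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in for MyModel: one linear layer named "fc".
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

# A checkpoint written by a DataParallel model (keys start with "module.").
old_checkpoint = nn.DataParallel(TinyModel()).state_dict()

model = nn.DataParallel(TinyModel())
model.load_state_dict(old_checkpoint)  # prefixes match, so strict loading succeeds

plain_model = model.module  # the underlying TinyModel, usable on its own
buffer = io.BytesIO()       # stands in for a new checkpoint path
torch.save(plain_model.state_dict(), buffer)  # saved keys have no "module." prefix
```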