Yes, I used nn.DataParallel. I didn’t fully understand your second suggestion: load the weights file, create a new ordered dict without the ‘module.’ prefix, and load it back. Can you provide an example?
I was thinking about something like the following:
# original saved file with DataParallel
state_dict = torch.load('myfile.pth.tar')
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove `module.` prefix
    new_state_dict[name] = v
# load params
model.load_state_dict(new_state_dict)
First of all, thanks a lot. Temporarily adding nn.DataParallel to my network just for loading purposes worked. I also tried your second suggested approach, and it worked for me as well. Thanks a lot!
@wasiahmad By adding nn.DataParallel temporarily into your network, did you need the same number of GPUs available to load the model as when you saved it?
A related question: given that saving a DataParallel-wrapped model can cause problems when its state_dict is loaded into an unwrapped model, would one recommend saving the “unwrapped” ‘module’ field of the DataParallel instance instead?
You may find that your checkpoint contains several keys, such as ‘state_dict’; in that case, pull out the state_dict first:
checkpoint = torch.load(resume)
state_dict = checkpoint['state_dict']

from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove 'module.' prefix added by DataParallel
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
What about nn.DistributedDataParallel? It seems DistributedDataParallel and DataParallel can load each other’s parameters.
Is there an official way to save/load among DDP, DP, and an unwrapped model?
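In case it helps, here is a minimal sketch (not an official API; the helper name is made up) of the prefix-stripping approach, which loads a checkpoint into an unwrapped model regardless of whether it was saved from DDP, DP, or a plain model — `model` below is assumed to be your unwrapped model:
import torch
from collections import OrderedDict

def strip_module_prefix(state_dict):
    # drop a leading 'module.' (added by DataParallel/DistributedDataParallel)
    # from every key; keys without the prefix are left untouched
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    return new_state_dict

state_dict = torch.load('myfile.pth.tar', map_location='cpu')
model.load_state_dict(strip_module_prefix(state_dict))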
Instead of deleting the “module.” prefix from all the state_dict keys, you can save your model with torch.save(model.module.state_dict(), path_to_file)
instead of torch.save(model.state_dict(), path_to_file).
That way you don’t get the “module.” prefix to begin with…
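For example, a minimal sketch of that round trip (MyModelNet stands in for your own nn.Module subclass, and 'checkpoint.pth' is just a placeholder path):
import torch
import torch.nn as nn

# training side: the model is wrapped for multi-GPU
model = nn.DataParallel(MyModelNet())
torch.save(model.module.state_dict(), 'checkpoint.pth')  # keys carry no 'module.' prefix

# loading side: a plain, unwrapped model accepts the checkpoint directly
plain_model = MyModelNet()
plain_model.load_state_dict(torch.load('checkpoint.pth', map_location='cpu'))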
The code I am using saves the model with torch.save(model)… In that case the model is loaded using args.pretrained = torch.load(args.pretrained)
on a single GPU. There, model is one of my own models, MyModelNet(nn.Module), but in the multi-GPU case it is nn.DataParallel(MyModelNet(nn.Module)).
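If you need to keep that whole-model torch.save(model) format, one workaround (just a sketch mirroring your args.pretrained usage, not necessarily the recommended approach) is to unwrap the DataParallel container after loading:
import torch
import torch.nn as nn

loaded = torch.load(args.pretrained, map_location='cpu')

# if the file was pickled while the model was wrapped in DataParallel,
# take the underlying MyModelNet instance out of the wrapper
if isinstance(loaded, nn.DataParallel):
    loaded = loaded.module

args.pretrained = loaded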