How to load the optimizer in a multi-GPU scenario?

Hi, as far as I know, the correct way to build a model is:

model = Model() #build the model
model = nn.DataParallel(model)
model.to(device) #move the model to device 
optimizer = optim.Adam(model.parameters()) #build the optimizer

Now assume I want to load the model parameters and the optimizer state from a pre-trained checkpoint (to resume training) in a multi-GPU setting. I am not sure where to load the optimizer state:

model = Model() #build the model on cpu
checkpoint = torch.load(pretrainedModel) # load the pre-trained model
model.load_state_dict(checkpoint['model'])
model = nn.DataParallel(model)
model.to(device) #move the model to device 
optimizer = optim.Adam(model.parameters()) #build the optimizer

It’s pretty much the same. After building the optimizer, you would call:

optimizer.load_state_dict(checkpoint['optim'])

Of course, this requires that the optimizer state dict was saved into the checkpoint previously, under the same key you read it from.
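For completeness, here is a minimal sketch of the save/load round trip on the CPU. Model is a hypothetical stand-in for your own class, the keys 'model' and 'optim' are only a convention, and an in-memory buffer stands in for a checkpoint file. The important part is the order: rebuild the model and optimizer first, then restore their states.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


# Hypothetical stand-in for the question's Model class.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = self.conv1(x).mean(dim=(2, 3))  # global average pool -> (N, 4)
        return self.fc(x)


model = Model()
optimizer = optim.Adam(model.parameters())

# One training step so that Adam holds per-parameter state (exp_avg, ...).
model(torch.randn(2, 1, 8, 8)).sum().backward()
optimizer.step()

# Save both state dicts; the loading code must read the same keys.
buffer = io.BytesIO()  # stands in for a path such as "checkpoint.pth"
torch.save({"model": model.state_dict(), "optim": optimizer.state_dict()}, buffer)

# Later: rebuild model and optimizer, then restore their states.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
new_model = Model()
new_model.load_state_dict(checkpoint["model"])
new_optimizer = optim.Adam(new_model.parameters())
new_optimizer.load_state_dict(checkpoint["optim"])
```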

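Also, be aware that saving a model that was wrapped in nn.DataParallel can result in errors when loading, because the wrapper adds a layer of abstraction: layers are reached as model.module.conv1 instead of model.conv1, so the state dict keys become module.conv1.weight instead of conv1.weight, for example. The optimizer state dict is not affected in the same way, because it refers to parameters by position rather than by name.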
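To see the prefix concretely, here is a minimal sketch with a hypothetical Model that has a single conv1 layer. nn.DataParallel can be constructed on a CPU-only machine (it then just holds the module), so this runs without GPUs. It compares the state dict keys with and without the wrapper, and shows that an Adam optimizer's state dict stores parameter positions, not names.

```python
import torch.nn as nn
import torch.optim as optim


# Hypothetical stand-in for the question's Model class.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)


plain_keys = list(Model().state_dict().keys())
wrapped = nn.DataParallel(Model())
wrapped_keys = list(wrapped.state_dict().keys())

print(plain_keys)    # ['conv1.weight', 'conv1.bias']
print(wrapped_keys)  # ['module.conv1.weight', 'module.conv1.bias']

# The optimizer refers to parameters by index, so there is no prefix to worry about.
opt = optim.Adam(wrapped.parameters())
param_ids = opt.state_dict()["param_groups"][0]["params"]
print(param_ids)     # [0, 1]
```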

This answer can help with loading the model (or you could call nn.DataParallel on your model before loading, but that only works when you actually want nn.DataParallel in your code).
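One common way to load a checkpoint saved from a wrapped model into a plain model is to strip the prefix from the keys before calling load_state_dict. This is a sketch, not a PyTorch API: strip_module_prefix is a hypothetical helper and Model is again a stand-in.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


# Hypothetical stand-in for the question's Model class.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)


def strip_module_prefix(state_dict, prefix="module."):
    """Hypothetical helper: remove a leading 'module.' from every key that has one."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


wrapped = nn.DataParallel(Model())           # e.g. the model that wrote the checkpoint
clean = strip_module_prefix(wrapped.state_dict())

plain = Model()
plain.load_state_dict(clean)                 # strict loading succeeds: keys now match
```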

This comment can help when saving nn.DataParallel models, so that you don’t need the first solution at all. However, I would recommend using both, so that you can handle every situation.
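The idea on the saving side is to store the state dict of the inner model rather than that of the wrapper, so the keys never get the prefix. A minimal sketch, assuming nn.DataParallel is the only wrapper in use; unwrapped_state_dict is a hypothetical helper name.

```python
import torch.nn as nn


# Hypothetical stand-in for the question's Model class.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)


def unwrapped_state_dict(model):
    """Return the inner model's state dict, so keys carry no 'module.' prefix."""
    if isinstance(model, nn.DataParallel):
        return model.module.state_dict()
    return model.state_dict()


wrapped = nn.DataParallel(Model())
state = unwrapped_state_dict(wrapped)

# The saved dict loads straight into a plain, unwrapped model.
Model().load_state_dict(state)
```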

Thank you. I have tried the following code snippet for the model, based on the link you proposed and another comment:

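import torch
import torch.nn as nn
import torch.optim as optim

model = Model() # build the model
model = nn.DataParallel(model)
model.to(device) # move the model to device
optimizer = optim.Adam(model.parameters()) # build the optimizer
checkpoint = torch.load(pretrainedModel, map_location=device) # load the pre-trained checkpoint
try:
    # works if the checkpoint was saved from a DataParallel model ('module.' keys)
    model.load_state_dict(checkpoint['model'])
except RuntimeError:
    # otherwise the keys have no prefix: load into the wrapped model instead
    model.module.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])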
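The try/except pattern above can be checked end to end on a single machine. This sketch assumes a checkpoint written from an unwrapped model (so its keys have no module. prefix), uses an in-memory buffer in place of the pretrainedModel file, and uses a hypothetical Model. The first load_state_dict call raises a RuntimeError because of the missing prefix, and the fallback loads into model.module; the optimizer state loads unchanged.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim


# Hypothetical stand-in for the question's Model class.
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)

    def forward(self, x):
        return self.conv1(x).mean()


# Simulate an old checkpoint written from an UNWRAPPED model and its optimizer.
old_model = Model()
old_optimizer = optim.Adam(old_model.parameters())
old_model(torch.randn(2, 1, 8, 8)).backward()
old_optimizer.step()
buffer = io.BytesIO()  # stands in for the file at pretrainedModel
torch.save({"model": old_model.state_dict(),
            "optimizer": old_optimizer.state_dict()}, buffer)

# Resume: wrap, build the optimizer, then load both states.
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")
model = nn.DataParallel(Model())
optimizer = optim.Adam(model.parameters())
try:
    # Only succeeds if the saved keys already carry the 'module.' prefix.
    model.load_state_dict(checkpoint["model"])
except RuntimeError:
    # Saved from an unwrapped model: load into the inner module instead.
    model.module.load_state_dict(checkpoint["model"])
# Optimizer state is keyed by parameter position, so the wrapper does not matter.
optimizer.load_state_dict(checkpoint["optimizer"])
```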