It’s pretty much the same; you would call:

```python
optimizer.load_state_dict(checkpoint['optim'])
```

Of course, this requires that you saved the optimizer’s state dict in the checkpoint beforehand.
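A minimal sketch of the full round trip follows, assuming the checkpoint is a plain dict written with `torch.save`. The `'optim'` key comes from the snippet above. The `'model'` key, the `checkpoint.pt` filename and the tiny `nn.Linear` model are hypothetical placeholders. Note that the new optimizer has to be built over the new model’s parameters before its state is loaded.

```python
import torch
import torch.nn as nn

# Hypothetical model and optimizer, just for illustration
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One training step so the optimizer has state (momentum buffers) worth saving
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()

# Save both state dicts in one checkpoint ('model' is an assumed key name)
checkpoint = {'model': model.state_dict(), 'optim': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pt')

# Later: rebuild the objects the same way, then restore their saved states
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1, momentum=0.9)
checkpoint = torch.load('checkpoint.pt')
new_model.load_state_dict(checkpoint['model'])
new_optimizer.load_state_dict(checkpoint['optim'])
```

The optimizer’s state dict refers to parameters by position rather than by name, which is why it must be paired with a model whose parameters are created in the same order.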
Also, be aware that saving a model that was wrapped in `nn.DataParallel` can result in errors when loading it, because the wrapper adds a layer of abstraction: the weights will look like `model.module.conv1` instead of `model.conv1`, for example, so every key in the saved state dict gains a `module.` prefix.
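To see the prefix the wrapper introduces, here is a small sketch with a hypothetical `Net` that has a single `conv1` layer. `nn.DataParallel` can also be constructed on a CPU-only machine (it then just forwards to the wrapped module), so this runs without a GPU.

```python
import torch.nn as nn

# Hypothetical model with one layer named conv1
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

model = Net()
wrapped = nn.DataParallel(model)  # the original model becomes wrapped.module

print(list(model.state_dict().keys()))    # ['conv1.weight', 'conv1.bias']
print(list(wrapped.state_dict().keys()))  # ['module.conv1.weight', 'module.conv1.bias']
```

Calling `Net().load_state_dict(wrapped.state_dict())` with the default `strict=True` fails, because none of the prefixed keys match the unwrapped model’s keys.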
This answer can help with loading the model (alternatively, you could wrap your model in `nn.DataParallel` before loading, but that is only an option when you actually want `nn.DataParallel` in your code).
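The linked answer isn’t reproduced here, but one common way to load such a checkpoint into an unwrapped model is to strip the prefix from the keys. The sketch below does that with a hypothetical helper, `strip_module_prefix`, and also shows the alternative of wrapping first and loading the prefixed keys as they are. `Net` is again a hypothetical model.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

def strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of state_dict with a leading DataParallel prefix removed from each key."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )

# Hypothetical model with one layer named conv1
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

# A state dict saved from a wrapped model: keys start with 'module.'
saved = nn.DataParallel(Net()).state_dict()

# Option 1: strip the prefix and load into a plain model
model = Net()
model.load_state_dict(strip_module_prefix(saved))

# Option 2: wrap first, then load the prefixed keys unchanged
wrapped = nn.DataParallel(Net())
wrapped.load_state_dict(saved)
```

Keys without the prefix pass through untouched, so the helper is also safe to call on checkpoints that were saved from an unwrapped model.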
This comment can help when saving `nn.DataParallel` models, so that you don’t need the first solution at all. However, I would recommend using both, so that you can handle every situation.
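On the saving side, here is a sketch assuming the idea is to save the inner module’s weights (`model.module.state_dict()`) so the checkpoint never contains the prefix. `unwrapped_state_dict`, `Net` and the `net.pt` filename are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn

def unwrapped_state_dict(model):
    # If the model is wrapped, save the inner module so keys carry no 'module.' prefix
    if isinstance(model, nn.DataParallel):
        model = model.module
    return model.state_dict()

# Hypothetical model with one layer named conv1
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv1(x)

wrapped = nn.DataParallel(Net())
torch.save(unwrapped_state_dict(wrapped), 'net.pt')

# The checkpoint now loads directly into a plain, unwrapped model
plain = Net()
plain.load_state_dict(torch.load('net.pt'))
```

Checking with `isinstance` means the same saving code works whether or not the model is wrapped.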