Saving and loading optimizers in Distributed Data Parallel situations

As given here:

torch.nn.DataParallel is a model wrapper that enables parallel GPU utilization. To save a DataParallel model generically, save the model.module.state_dict(). This way, you have the flexibility to load the model any way you want to any device you want.
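To make the quoted advice concrete, here is a minimal sketch (the model and file name are stand-ins, not from the tutorial, and it assumes a CUDA machine for the wrapper to do anything):

```python
import torch
import torch.nn as nn

# Stand-in model wrapped for multi-GPU training (assumes CUDA is available).
model = nn.DataParallel(nn.Linear(10, 2)).cuda()

# Save the underlying module's parameters rather than the wrapper's.
torch.save(model.module.state_dict(), "checkpoint.pt")

# The checkpoint can later be loaded onto any device, e.g. plain CPU:
plain_model = nn.Linear(10, 2)
plain_model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
```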

Considering the discussion just above this about saving GPU models and loading them on CPU, I'm guessing the quoted line refers to the data-parallel model living on any of the available GPUs, with model.module being the underlying module that is somehow device agnostic. (Please correct me if I'm wrong.)

That being said, what happens to the optimizer's internal state variables (mentioned here)? They too would be on whichever GPU rank they were saved from. Is this just solved by passing the map_location argument to torch.load()? If so, why the special treatment for the DataParallel model (saving model.module.state_dict())?

That’s correct. If you serialize the nn.DataParallel module itself, it contains the list of devices you parallelize over, the dimension to split the input batch on, etc. When you serialize the inner module, none of that is included, and it is up to you to set it up again (or not) after you load it.
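A small sketch of that difference (again assuming a CUDA machine; the file names are placeholders):

```python
import torch
import torch.nn as nn

wrapped = nn.DataParallel(nn.Linear(10, 2)).cuda()

# Pickling the wrapper stores the DataParallel object itself, including
# attributes like wrapped.device_ids and wrapped.dim (the batch split
# dimension), which ties the checkpoint to this particular parallel setup.
torch.save(wrapped, "wrapped.pt")

# Saving only the inner module's parameters leaves those settings out;
# re-wrapping in DataParallel (or not) is up to you after loading.
torch.save(wrapped.module.state_dict(), "inner.pt")
```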

Any optimizer state will likely contain tensors that live on specific devices. You’re correct that you can use map_location to remap them at load time. Alternatively, you can copy the optimizer state to CPU first, then serialize it, and not worry about it at load time. What special treatment exactly are you talking about?
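For illustration, a sketch of both options (the model, optimizer, and file names are stand-ins; SGD with momentum is used only so the optimizer holds some state, and a GPU is assumed):

```python
import copy
import torch

# Hypothetical setup: a small model and optimizer living on the GPU.
model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One dummy step so the optimizer actually holds state (momentum buffers).
model(torch.randn(4, 10, device="cuda")).sum().backward()
optimizer.step()

# Option 1: save as-is and remap devices when loading.
torch.save(optimizer.state_dict(), "optim.pt")
state = torch.load("optim.pt", map_location="cpu")

# Option 2: copy the state to CPU before saving, so the file is
# device-agnostic and no map_location is needed later.
cpu_state = copy.deepcopy(optimizer.state_dict())
for param_state in cpu_state["state"].values():
    for key, value in param_state.items():
        if torch.is_tensor(value):
            param_state[key] = value.cpu()
torch.save(cpu_state, "optim_cpu.pt")
```

When the file is loaded back with optimizer.load_state_dict(), floating-point state tensors are generally cast to the device of the matching parameters anyway, so either option ends up with the state where the parameters are.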


Nothing more! My model is training but the error doesn’t seem to be coming down, so I was just exploring the possibilities. This clears it up! Thanks!

Glad to help. Good luck figuring it out.