Restore optimizer for sharded model

I am using a sharded model (part of the model on GPU 0 and part of it on GPU 1), as suggested here: Sharding model across GPUs - #2 by ajdroid

Since I sometimes need to stop and restart training, I save both the model's state_dict and the optimizer's state_dict to disk. The problem arises when I restore the model and the optimizer.
This is the code that loads the states:

state = torch.load(path, map_location='cpu')
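For context, here is a minimal, self-contained sketch of the save/restore round trip I am doing (the model below is a tiny stand-in, not my actual network):

```python
import torch
import torch.nn as nn

# Stand-in model; my real network is larger and sharded across GPUs.
model = nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer actually has state
# (Adam keeps exp_avg / exp_avg_sq tensors per parameter).
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optim.step()

# Save both state_dicts in one checkpoint file.
torch.save({'model': model.state_dict(), 'optim': optim.state_dict()},
           'checkpoint.pt')

# Restore: everything is mapped to CPU first.
state = torch.load('checkpoint.pt', map_location='cpu')
model.load_state_dict(state['model'])
optim.load_state_dict(state['optim'])
```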

Afterwards I run a function that moves the child modules of my model to the appropriate GPUs:


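Roughly like this (a sketch only; the module names and the two-part split are made up for illustration, my real network has more children):

```python
import torch
import torch.nn as nn

class ShardedNet(nn.Module):
    # Stand-in for my real network: two halves on different devices.
    def __init__(self):
        super().__init__()
        self.part1 = nn.Linear(4, 8)
        self.part2 = nn.Linear(8, 2)

    def shard(self, dev0, dev1):
        # Move each child module to its own device.
        self.part1.to(dev0)
        self.part2.to(dev1)

    def forward(self, x):
        h = self.part1(x.to(self.part1.weight.device))
        # Hop to the second shard's device between the halves.
        return self.part2(h.to(self.part2.weight.device))
```

With two GPUs this is called as `model.shard('cuda:0', 'cuda:1')`; on a CPU-only machine both devices can be `'cpu'`.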
Previously I relied on the following code to move the optimizer's state tensors to the GPU:

for param in optim.state.values():
    # State entries can be tensors or dicts of tensors (e.g. Adam's exp_avg)
    if isinstance(param, torch.Tensor):
        param.data = param.data.cuda()
        if param._grad is not None:
            param._grad.data = param._grad.data.cuda()
    elif isinstance(param, dict):
        for subparam in param.values():
            if isinstance(subparam, torch.Tensor):
                subparam.data = subparam.data.cuda()
                if subparam._grad is not None:
                    subparam._grad.data = subparam._grad.data.cuda()

With the sharded model this failed, with an error saying that elements are on different CUDA devices (which makes sense, since `.cuda()` moves everything to a single default device).

Afterwards I tried leaving out the part that moves the optimizer's state to the GPU, and my program started working again.

My questions:

  1. Am I missing something? When I do state = torch.load(path, map_location='cpu'), all tensors are loaded to the CPU, so how can the optimizer's state tensors now end up on the right devices?
  2. If I have misunderstood something: what is the appropriate way to handle optimizer state when checkpointing a model, both in the un-sharded and the sharded scenario?

For reference: I am using the Adam optimizer.

Thanks in advance