Restore optimizer for sharded model

Hi,
I am using a sharded model (part of the model on GPU 0 and part on GPU 1), as suggested here: Sharding model across GPUs - #2 by ajdroid

Since I sometimes need to stop and restart training, I save both the state_dict of the model and that of the optimizer to disk. The problem arises when I restore the model and the optimizer.
This is the code that loads the states:
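For completeness, the checkpoint is written roughly like this (a sketch; the key names match the loading code below):

```python
import torch

def save_checkpoint(model, optimizer, path):
    # Both state dicts go into one file; tensors keep their current devices.
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)
```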

state = torch.load(path, map_location='cpu')
model.load_state_dict(state["model_state_dict"])
optimizer.load_state_dict(state["optimizer_state_dict"])

Afterwards I run a function that moves the child modules of my model to the appropriate GPUs:

model.move_to_gpu()  
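(Roughly, `move_to_gpu` does something like the following. This is a hypothetical sketch: the child module names `stage0`/`stage1` and the two-stage split are placeholders; my actual model differs.)

```python
import torch.nn as nn

class ShardedNet(nn.Module):
    # Hypothetical two-stage model; the real child modules differ.
    def __init__(self):
        super().__init__()
        self.stage0 = nn.Linear(8, 8)
        self.stage1 = nn.Linear(8, 2)

    def move_to_gpu(self, dev0="cuda:0", dev1="cuda:1"):
        # Each child module goes to its own device.
        self.stage0.to(dev0)
        self.stage1.to(dev1)

    def forward(self, x):
        x = self.stage0(x)
        # Hop the activation over to the second shard's device.
        return self.stage1(x.to(self.stage1.weight.device))
```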

Previously I relied on the following code to move the optimizer's parameters to the GPU:

for param in optim.state.values():
    if isinstance(param, torch.Tensor):
        param.data = param.data.to(device)
        if param._grad is not None:
            param._grad.data = param._grad.data.to(device)
    elif isinstance(param, dict):
        for subparam in param.values():
            if isinstance(subparam, torch.Tensor):
                subparam.data = subparam.data.to(device)
                if subparam._grad is not None:
                    subparam._grad.data = subparam._grad.data.to(device)

This obviously failed in the sharded setup, with an error saying that tensors are on different CUDA devices (it moves all optimizer state to one single device).

When I left out the part that moves the optimizer's parameters to the GPU, my program started working again.
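One variant of the loop above that I considered (a sketch, not something I have verified in training) moves each state tensor to the device of the parameter it belongs to, instead of to one fixed device. `optimizer.state` is keyed by parameter, so the target device can be read off each key:

```python
import torch

def move_optimizer_state_per_param(optimizer):
    # optimizer.state maps each parameter to its state dict (for Adam:
    # exp_avg, exp_avg_sq, step), so every state tensor can be sent to
    # the device of its own parameter rather than to one global device.
    for param, state in optimizer.state.items():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(param.device)
```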

My questions:

  1. Am I missing something? With state = torch.load(path, map_location='cpu') all tensors are loaded to the CPU, so how can the optimizer's parameters end up on the right devices afterwards?
  2. In case I misunderstood something: what is the appropriate way to handle the optimizer state when checkpointing a model, in both the un-sharded and the sharded scenario?

For reference: I am using the Adam optimizer.

Thanks in advance