I am using a sharded model (part of the model on GPU 0 and part on GPU 1), as suggested here: Sharding model across GPUs - #2 by ajdroid.
Since I sometimes need to stop and restart training, I save both the state_dict of the model and that of the optimizer to disk. The problem arises when I restore the model and the optimizer.
This is the code that loads the states:

```python
state = torch.load(path, map_location='cpu')
model.load_state_dict(state["model_state_dict"])
optimizer.load_state_dict(state["optimizer_state_dict"])
```
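For completeness, the checkpoint is saved roughly like this (a minimal sketch; the dictionary keys match the load code above):

```python
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, path)
```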
Afterwards I run a function that moves the child modules of my model to the appropriate GPUs.
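Roughly, that function does the following (a minimal sketch; the child module names are placeholders for my actual architecture):

```python
def move_to_devices(model):
    # Hypothetical child names: each block of the sharded model
    # is sent to its own GPU like this.
    model.part1.to('cuda:0')
    model.part2.to('cuda:1')
```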
Previously I relied on the following code to move the optimizer's state to the GPU:
```python
for param in optim.state.values():
    if isinstance(param, torch.Tensor):
        param.data = param.data.to(device)
        if param._grad is not None:
            param._grad.data = param._grad.data.to(device)
    elif isinstance(param, dict):
        for subparam in param.values():
            if isinstance(subparam, torch.Tensor):
                subparam.data = subparam.data.to(device)
                if subparam._grad is not None:
                    subparam._grad.data = subparam._grad.data.to(device)
```
This obviously failed with an error saying that tensors are on different CUDA devices, since the code moves all state to a single device while the model is sharded across two.
I then left out the step that moves the optimizer's state to the GPU, and my program started working again.
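A per-parameter variant like the one below seems more natural for the sharded case, since it moves each state tensor to the device of its own parameter, but I have not verified that it is correct (untested sketch):

```python
# Untested sketch: move each Adam state tensor (exp_avg, exp_avg_sq, ...)
# to the device of the parameter it belongs to, instead of one global device.
for param, state in optim.state.items():
    for key, value in state.items():
        if isinstance(value, torch.Tensor):
            state[key] = value.to(param.device)
```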
- Am I missing something? When I do `state = torch.load(path, map_location='cpu')` I load the tensors to the CPU, so how can the optimizer's state tensors end up on the right device afterwards?
- If I misunderstood something: what is the appropriate way to handle optimizer state when checkpointing a model, both in the unsharded and the sharded scenario?
For reference: I am using the Adam optimizer.
Thanks in advance