Moving optimizer from CPU to GPU

I have a model and an optimizer, and I want to save their state dicts as CPU tensors, then load those state dicts back onto the GPU. This seems straightforward for the model, but what's the best way to do it for the optimizer?

This is what my code looks like right now:

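import torch

model = ...
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.1)  # illustrative lr; lr is required on older PyTorch versions

model_state = model.state_dict()
# Convert to CPU
for k, v in model_state.items():
    model_state[k] = v.cpu()

optim_state = optim.state_dict()
# Convert to CPU. Build new per-parameter dicts rather than assigning into the
# existing ones: in current PyTorch, optim_state["state"] shares them with optim.
optim_state["state"] = {
    idx: {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
    for idx, state in optim_state["state"].items()
}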

# Now I want to load these state dicts back onto GPU
model2 = ...
model2.cuda()
optim2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.1)


# This seems to work; the model2 parameters are on GPU
model2.load_state_dict(model_state)

# The same does not hold for the optimizer
optim2.load_state_dict(optim_state)
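The CPU conversion above can also be written once as a small recursive helper that handles both state dicts, including nested entries such as param_groups. This is a minimal sketch; to_device is a hypothetical name, not a PyTorch API, and nn.Linear is only a stand-in model. It rebuilds dicts, lists and tuples instead of assigning into them (in current PyTorch the per-parameter dicts in optim.state_dict()["state"] are shared with the live optimizer), and it passes copy=True so the result is a snapshot even when a tensor is already on the target device.

```python
import torch
import torch.nn as nn


def to_device(obj, device):
    """Hypothetical helper: return a copy of obj with every tensor on device.

    Containers are rebuilt rather than modified in place, so the model or
    optimizer that produced obj is left untouched.
    """
    if isinstance(obj, torch.Tensor):
        # copy=True forces a fresh tensor even if obj is already on device
        return obj.to(device, copy=True)
    if isinstance(obj, dict):
        return {k: to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device) for v in obj)
    # ints, floats, bools, None and strings are passed through unchanged
    return obj


# Stand-in model; one step so SGD creates a momentum buffer per parameter.
model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.1)
model(torch.randn(8, 4)).sum().backward()
optim.step()

cpu_model_state = to_device(model.state_dict(), "cpu")
cpu_optim_state = to_device(optim.state_dict(), "cpu")
```

Because the helper only ever creates new objects, later training steps on optim will not change the saved snapshot, and the snapshot can be handed to torch.save as-is.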

The only option I see is to manually convert the optimizer state back to CUDA:

for state in optim2.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):  # skip non-tensor entries such as step counters
            state[k] = v.cuda()

But would optim2 still update model2’s parameters?
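One way to answer this empirically is a small round trip. Optimizer state is keyed by the parameter objects themselves, and the loop above only reassigns values inside each per-parameter dict, never the keys, so optim2 should keep stepping model2's parameters. The sketch below checks that, assuming an nn.Linear stand-in for the real model and falling back to CPU when no GPU is present. On PyTorch versions whose load_state_dict already casts state tensors to each parameter's device, the final loop is effectively a no-op.

```python
import copy

import torch
import torch.nn as nn

# Use the GPU when there is one; otherwise everything stays on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Original model/optimizer; one step so SGD creates momentum buffers.
model = nn.Linear(4, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.1)
model(torch.randn(8, 4)).sum().backward()
optim.step()

# CPU copies of both state dicts (deepcopy so optim itself is left alone).
model_state = {k: v.cpu() for k, v in model.state_dict().items()}
optim_state = copy.deepcopy(optim.state_dict())
for state in optim_state["state"].values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.cpu()

# New model on the target device, optimizer built over its parameters.
model2 = nn.Linear(4, 2).to(device)
optim2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.1)
model2.load_state_dict(model_state)
optim2.load_state_dict(optim_state)

# The loop from the question: only values change, the Parameter keys do not.
for state in optim2.state.values():
    for k, v in state.items():
        if isinstance(v, torch.Tensor):
            state[k] = v.to(device)

# optim2 still refers to model2's own Parameter objects...
assert all(p in optim2.state for p in model2.parameters())
assert all(q is p for q, p in zip(optim2.param_groups[0]["params"], model2.parameters()))
# ...its state lives on the same device as those parameters...
assert all(s["momentum_buffer"].device == p.device for p, s in optim2.state.items())
# ...and a step actually changes model2's weights.
before = model2.weight.detach().clone()
model2(torch.randn(8, 4, device=device)).sum().backward()
optim2.step()
assert not torch.equal(before, model2.weight)
```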


I have the same problem.
I have limited GPU memory. I can train with the model and optimizer on the GPU. However, GPU memory surges when loading the model and optimizer onto the GPU, see https://github.com/pytorch/pytorch/issues/7415
The effect is that I can't load a previous checkpoint directly onto the GPU during training without going OOM. For the model, loading to CPU first and then moving it to the GPU works (see code below).
Now I go OOM when loading the optimizer. I would like to load the optimizer to CPU first and then move it to the GPU. How can I do this?

# load model
# OOM triggered when directly loading to GPU, see https://github.com/pytorch/pytorch/issues/7415
# params = torch.load(model_save_path, map_location=lambda storage, loc: storage)
# Instead, load to CPU first
params = torch.load(model_save_path, map_location='cpu')
model.load_state_dict(params['state_dict'])
# And move model to GPU
model = model.to(device)

# Goes OOM - How can I load to CPU and move to GPU?
optimizer.load_state_dict(torch.load(model_save_path + '.optim'))

Hello @amogkam,
Here, https://github.com/pytorch/pytorch/issues/8741, is an old feature request for a PyTorch function that moves an optimizer to a device. I use the optimizer_to function posted there, which gets me around the OOM; training looks good so far.

import torch

def optimizer_to(optim, device):
    # optim.state maps each parameter to a dict of its state tensors
    # (e.g. SGD's momentum_buffer), so the dict branch does most of the work.
    for param in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(param, torch.Tensor):
            param.data = param.data.to(device)
            if param._grad is not None:
                param._grad.data = param._grad.data.to(device)
        elif isinstance(param, dict):
            for subparam in param.values():
                if isinstance(subparam, torch.Tensor):
                    subparam.data = subparam.data.to(device)
                    if subparam._grad is not None:
                        subparam._grad.data = subparam._grad.data.to(device)

Here’s how I use it

# Load model
params = torch.load(model_save_path, map_location='cpu')
model.load_state_dict(params['state_dict'])
model = model.to(device)

# Empty any cache; not sure this helps, but we try what we can
torch.cuda.empty_cache()

# Load optimizer
# Load to CPU first
optimizer.load_state_dict(torch.load(model_save_path + '.optim', map_location='cpu'))
# Send to GPU
optimizer_to(optimizer, device)

The best solution would be for PyTorch to provide an optim.to(device) interface for the optimizer, similar to model.to(device).

Another solution would have been to save the tensors in state dicts without a device attached, so that loading would not produce this discrepancy between the model state dict and the optimizer state dict.

For example, if none of the tensors in the model and optimizer state dicts had a device associated with them, then simply doing

state_dict = torch.load('some_model.pt')
model.load_state_dict(state_dict['model'])
optim.load_state_dict(state_dict['optim'])

would have worked fine.
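Something close to this can be approximated today by loading with map_location='cpu', which remaps every storage to the CPU regardless of the device it was saved from. In recent PyTorch versions, Optimizer.load_state_dict also appears to cast floating-point state tensors to the device of the matching parameter, so loading into an optimizer built over already-moved parameters may be enough. A minimal sketch under those assumptions, using nn.Linear as a stand-in model and an in-memory buffer in place of 'some_model.pt':

```python
import io

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in model trained briefly on the target device.
model = nn.Linear(4, 2).to(device)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.1)
model(torch.randn(8, 4, device=device)).sum().backward()
optim.step()

# Save both state dicts in one checkpoint; BytesIO stands in for a file path.
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optim": optim.state_dict()}, buffer)
buffer.seek(0)

# map_location="cpu": every tensor is materialized on the CPU while loading.
state_dict = torch.load(buffer, map_location="cpu")

# Move the new model first, then build its optimizer and load both states.
model2 = nn.Linear(4, 2).to(device)
optim2 = torch.optim.SGD(model2.parameters(), lr=0.1, momentum=0.1)
model2.load_state_dict(state_dict["model"])
optim2.load_state_dict(state_dict["optim"])

# Each momentum buffer should now sit on the same device as its parameter.
for p in model2.parameters():
    print(p.device, optim2.state[p]["momentum_buffer"].device)
```

If the printed devices differ on your PyTorch version, the optimizer_to helper from earlier in the thread can still be applied after loading.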