Hello @amogkam,
Here https://github.com/pytorch/pytorch/issues/8741 is an old feature request for a pytorch fct to move optimizer to device. I use the optimizer_to function posted there and get around my OOM, training looks good so far.
def optimizer_to(optim, device):
for param in optim.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
Here’s how I use it
# Load model
params = torch.load(model_save_path, map_location='cpu')
model.load_state_dict(params['state_dict'])
model = model.to(device)
# Empty any cache, not sure this helps, we try waht we can
torch.cuda.empty_cache()
# Load optimizer
# Load to CPU first
optimizer.load_state_dict(torch.load(model_save_path + '.optim',map_location='cpu'))
# Send to GPU
optimizer_to(optimizer,device)