Loading specific optimizer tensors onto different devices using torch.load

Hi
I have a neural net model and its optimizer state saved in a pickled checkpoint file (excuse me if my terminology is imprecise). The model and optimizer tensors were all saved from the GPU, so when the checkpoint is loaded with torch.load, they are loaded onto the GPU again. However, I would like some specific optimizer state tensors (e.g., the exponential moving average momentum tensor of the Adam optimizer) to be loaded onto the CPU instead of the GPU. I have found very little documentation on the map_location argument of torch.load that explains how to write a function that does this. Could you please provide an example of how to do this for one or more tensors? For definiteness, you can assume that the optimizer tensors to be loaded onto the CPU are state['exp_avg'] and state['exp_avg_sq'].
Thanks!

You can pass map_location='cpu' to torch.load to map every tensor stored in the checkpoint onto the CPU.
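Besides a device string, map_location also accepts a torch.device, a dict that remaps location tags (e.g. {'cuda:1': 'cuda:0'}), or a callable. According to the torch.load documentation, a callable is invoked once per serialized storage with two arguments: the storage, already deserialized on the CPU, and the location tag it was saved from (such as 'cpu' or 'cuda:0'). It never sees the state-dict key (e.g. 'exp_avg'), so it cannot single out specific optimizer entries by name. The sketch below only records the tags it is called with; the helper name make_recording_map_location is hypothetical, and a small CPU nn.Linear stands in for the real model so that it runs without a GPU.

```python
import io

import torch


def make_recording_map_location(seen):
    # Hypothetical helper for illustration: builds a map_location callable
    # that records every location tag it receives.
    def map_location(storage, location):
        # `storage` is already deserialized on the CPU; `location` is the
        # device tag it was saved from. Only the storage is available here,
        # not the state-dict key it belongs to.
        seen.append(location)
        return storage  # returning it unchanged keeps it on the CPU
    return map_location


# Tiny CPU-only stand-in for the real GPU model and optimizer.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()  # populates exp_avg / exp_avg_sq

buffer = io.BytesIO()
torch.save(optimizer.state_dict(), buffer)
buffer.seek(0)

seen = []
state_dict = torch.load(buffer, map_location=make_recording_map_location(seen))
print(seen)
```

Returning the storage unchanged is equivalent to the map_location=lambda storage, loc: storage form shown in the torch.load docs. For a checkpoint saved from the GPU, the recorded tags would include the CUDA devices the tensors came from.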
Here is a small example showing the usage of map_location='cpu':

```python
# save optimizer.state_dict() with CUDA tensors
import torch
import torchvision.models as models

model = models.resnet18().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1.)
out = model(torch.randn(1, 3, 224, 224).cuda())
out.mean().backward()
optimizer.step()

torch.save(optimizer.state_dict(), 'opt.pth')
print(optimizer.state_dict()['state'][0]['exp_avg'].device)
# > cuda:0
```

```python
# in a new script, hide the GPU and load the state_dict onto the CPU
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""  # must be set before CUDA is initialized
import torch
import torchvision.models as models

torch.cuda.get_device_name()
# > RuntimeError: No CUDA GPUs are available

model = models.resnet18()
optimizer = torch.optim.Adam(model.parameters(), lr=1.)
optimizer.load_state_dict(torch.load('opt.pth'))
# > RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
optimizer.load_state_dict(torch.load('opt.pth', map_location='cpu'))  # works
```
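Because map_location cannot pick out individual entries, one workable approach is to load the checkpoint as usual and then move only the selected entries of the loaded state dict. The sketch below assumes the standard optimizer state_dict layout ({'state': {param_id: {...}}, 'param_groups': [...]}) and uses a hypothetical helper, move_optimizer_state, which moves the tensors stored under 'exp_avg' and 'exp_avg_sq' to a target device while leaving everything else (e.g. 'step') where it is. A tiny CPU nn.Linear again stands in for the GPU model so the snippet runs anywhere; with a real GPU checkpoint, the moment tensors would move from CUDA to the CPU.

```python
import io

import torch


def move_optimizer_state(state_dict, keys=('exp_avg', 'exp_avg_sq'), device='cpu'):
    """Hypothetical helper: return a new optimizer state_dict in which only
    the tensors stored under `keys` (for every parameter) are moved to
    `device`. Other entries are shared with the input, not copied."""
    new_state = {}
    for param_id, param_state in state_dict['state'].items():
        new_state[param_id] = {
            name: value.to(device) if name in keys and torch.is_tensor(value) else value
            for name, value in param_state.items()
        }
    return {'state': new_state, 'param_groups': state_dict['param_groups']}


# Tiny CPU-only stand-in for the real GPU model and optimizer.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model(torch.randn(3, 4)).sum().backward()
optimizer.step()

buffer = io.BytesIO()
torch.save(optimizer.state_dict(), buffer)
buffer.seek(0)

# Load without map_location so tensors return to the device they were saved
# from, then move only the Adam moment estimates.
checkpoint = torch.load(buffer)
checkpoint = move_optimizer_state(checkpoint, device='cpu')
for param_id, param_state in checkpoint['state'].items():
    print(param_id, param_state['exp_avg'].device, param_state['exp_avg_sq'].device)
```

Keep in mind that Optimizer.load_state_dict casts state tensors to the device of their corresponding parameter, so passing this dict to an optimizer whose parameters live on the GPU moves exp_avg and exp_avg_sq back to the GPU. Moving them to the CPU this way is mainly useful for inspecting, saving, or manually offloading the optimizer state.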