`torch.load` does not map to GPU as advertised

I cannot get torch.load and map_location to work as expected. I have tried three of the suggested methods for loading a model onto the GPU using map_location (from the references listed below). The model always ends up on the CPU, even though the documentation seems to indicate that map_location can load tensors directly onto the GPU. (It even says this should happen by default on a machine with GPUs.) Here is a minimal working example, followed by further questions.

import torch
import torch.nn as nn
import torch.nn.functional as F

gpu = torch.device("cuda:0")
print(gpu)

cuda:0

# Build a simple model on the GPU
class Net(nn.Module):
    def __init__(self, indim):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(indim, 5)
        self.fc2 = nn.Linear(5, 1)
        
    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = self.fc2(t)
        return t

    
net = Net(3).to(gpu)

# Force weights to zero so we can later confirm 
# we've loaded the right model
with torch.no_grad():
    net.fc1.weight.fill_(0)
    
print(net.fc1.weight)

Parameter containing:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], device='cuda:0', requires_grad=True)

Notice that my model is on cuda:0.

# Save and delete the model
torch.save(net.state_dict(), 'my_gpu_mod.pth')
del net

Now we try three different ways to load and map_location to the GPU.

# Load with map_location=device
model = Net(3)
print(model.fc1.weight.max())

model.load_state_dict(torch.load("my_gpu_mod.pth",
                                map_location=gpu))
    
print(model.fc1.weight.device)
print(model.fc1.weight.max())

tensor(0.5071, grad_fn=<MaxBackward1>)
cpu
tensor(0., grad_fn=<MaxBackward1>)

# Load with map_location=string
model = Net(3)
print(model.fc1.weight.max())

model.load_state_dict(torch.load("my_gpu_mod.pth",
                                map_location="cuda:0"))
    
print(model.fc1.weight.device)
print(model.fc1.weight.max())

tensor(0.4118, grad_fn=<MaxBackward1>)
cpu
tensor(0., grad_fn=<MaxBackward1>)

# Load with map_location=lambda
model = Net(3)
print(model.fc1.weight.max())

model.load_state_dict(torch.load('my_gpu_mod.pth',
                      map_location=lambda storage, loc: storage.cuda(0)))
    
print(model.fc1.weight.device)
print(model.fc1.weight.max())

tensor(0.5070, grad_fn=<MaxBackward1>)
cpu
tensor(0., grad_fn=<MaxBackward1>)

If map_location does not automatically put things on the GPU, why does it need to be used at all when loading a GPU-trained model on the same machine that trained it? I can simply call .to afterwards.

References

  1. Saving and Loading Models — PyTorch Tutorials 2.1.1+cu121 documentation
  2. On a cpu device, how to load checkpoint saved on gpu device - #4 by apaszke
  3. torch — PyTorch 2.1 documentation

Could you try to run the code without any available GPUs via export CUDA_VISIBLE_DEVICES= before running the script?
I assume the load method is smart enough to copy the device state_dict back to your CPU model.

EDIT: This issue doesn’t seem to be reproducible using the latest nightly binary, so you might just update.
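For this suggestion, the variable has to be set to an empty value (the shell form is `CUDA_VISIBLE_DEVICES= python your_script.py`), not merely left unset, and it must take effect before the process initializes CUDA. A minimal sketch, assuming torch is installed for the current interpreter; the helper name `cuda_visible_with_empty_mask` is hypothetical. It starts a child Python process with the variable emptied and reports whether that process can see a GPU:

```python
import os
import subprocess
import sys


def cuda_visible_with_empty_mask() -> str:
    """Hypothetical helper: run a child interpreter with every GPU hidden."""
    # Copy the current environment, but set CUDA_VISIBLE_DEVICES to an empty
    # string so the child process sees no CUDA devices at all
    env = dict(os.environ, CUDA_VISIBLE_DEVICES="")
    result = subprocess.run(
        [sys.executable, "-c", "import torch; print(torch.cuda.is_available())"],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


print(cuda_visible_with_empty_mask())
```

With the GPU hidden this way, torch.cuda.is_available() in the child should report False even on a machine that has a GPU.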

(1) In fact, I did run without CUDA_VISIBLE_DEVICES. The machine I’m using only has one GPU and when I run without setting CUDA_VISIBLE_DEVICES, the machine automatically knows to use the only GPU as cuda:0.

(2) I’m not sure I understand this quote

I assume the load method is smart enough to copy the device state_dict back to your CPU model.

Actually, I am trying to get the map_location argument to load the GPU-trained model from disk directly into the GPU’s VRAM (as advertised in the documentation), without going through the CPU. So I thought it should not end up on the CPU.

(3) Thank you for the suggestion. Since my goal is just to understand and I’m wary of nightly binaries, I will pass on them for now. But to clarify, you are saying, in the newest version, my code above automatically loads the model to the GPU’s VRAM without first loading into CPU memory then transferring to the GPU?

Thanks.

I have made some progress on this issue by using torch.cuda.memory_summary rather than mytensor.device. It turns out that using map_location with options that specify the GPU reserves GPU memory, but the model’s parameters still report their device as “cpu”. In particular, it seems that the tensors are allocated on the GPU and then immediately freed. With the default map_location, however, no GPU memory is reserved at all: it is never allocated and thus never freed (at least until you call mytensor.to()).
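The allocate-then-free pattern can also be checked with the individual counters behind memory_summary: torch.cuda.memory_allocated counts live tensors, max_memory_allocated records the peak since the last reset, and memory_reserved includes memory the caching allocator keeps after tensors are freed. A minimal sketch, assuming a single GPU at "cuda:0" and using a hypothetical 1024x1024 nn.Linear checkpoint in an in-memory buffer in place of the real file; it returns None when no GPU is available:

```python
import io

import torch
import torch.nn as nn


def one_liner_gpu_memory(device: str = "cuda:0"):
    """Hypothetical helper: GPU memory counters around a one-line load."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.init()  # make sure the CUDA state exists before resetting stats

    # Stand-in checkpoint (an assumption): roughly 4 MB of float32 weights
    buf = io.BytesIO()
    torch.save(nn.Linear(1024, 1024).state_dict(), buf)
    buf.seek(0)

    model = nn.Linear(1024, 1024)  # the model itself stays on the CPU
    torch.cuda.reset_peak_memory_stats(device)
    before = torch.cuda.memory_allocated(device)

    # The one-liner: the loaded dict lives on the GPU only until this returns
    model.load_state_dict(torch.load(buf, map_location=device))
    torch.cuda.synchronize(device)

    return {
        "peak_extra_bytes": torch.cuda.max_memory_allocated(device) - before,
        "extra_bytes_after": torch.cuda.memory_allocated(device) - before,
        "reserved_bytes": torch.cuda.memory_reserved(device),
        "param_device": str(model.weight.device),
    }


print(one_liner_gpu_memory())
```

On a GPU machine one would expect a nonzero peak but little or no extra allocated memory afterwards, while the parameters still report "cpu".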

By tracking the process’s CPU memory with psutil, I found that using map_location to the GPU causes the total CPU memory of the process to spike immediately after calling torch.load. With the default map_location, however, total CPU memory does not increase until after you call mytensor.to(). Both approaches (with or without map_location, followed by .to()) end up using about the same process CPU memory; map_location to the GPU just reaches that level right after torch.load, whereas map_location to the CPU only reaches it after .to().

I’m left wondering whether this was intended behavior.

I still think that you are hunting down two separate issues.
torch.load(..., map_location) will load the tensor or state_dict onto the specified device.
However, since you are piping this operation directly into model.load_state_dict, internally param.copy_ is most likely used, which respects the parameter’s device and copies each state_dict tensor to the device of the corresponding model parameter.

Try to separate the calls and check the tensors in the state_dict right after using torch.load.
You should see that the desired device passed via map_location is used.
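The copy behaviour described above can be observed even without a GPU. This minimal sketch uses dtype as a stand-in for device (an illustrative assumption: copy_ writes into the existing parameter, so it keeps the destination’s dtype just as it keeps the destination’s device). A float64, all-zero state_dict is loaded into a float32 nn.Linear:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # float32 parameters on the CPU
ptr_before = model.weight.data_ptr()  # address of the parameter's storage

# A state_dict whose tensors differ from the parameters in dtype (float64)
# and are all zeros, so we can tell that the values were actually loaded
sd = {k: v.double().zero_() for k, v in model.state_dict().items()}

model.load_state_dict(sd)  # default behaviour: copy into existing parameters

print(model.weight.dtype)                     # the parameter keeps float32
print(model.weight.data_ptr() == ptr_before)  # same storage, copied in place
print(bool(model.weight.eq(0).all()))         # values came from sd
```

The parameter keeps its own dtype and storage; only the values are copied. With a GPU state_dict and a CPU model, the same mechanism copies GPU values into CPU parameters.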

If you want to avoid the host and device memory allocation, make sure that the model is on the same device as the state_dict before calling .load_state_dict.
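A minimal sketch of that ordering, assuming a plain nn.Linear in place of Net and an in-memory buffer in place of the checkpoint file; it falls back to the CPU so it also runs on machines without a GPU:

```python
import io

import torch
import torch.nn as nn

# Use the GPU when present; otherwise the sketch still runs on the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Stand-in checkpoint written to a buffer instead of a .pth file
buf = io.BytesIO()
torch.save(nn.Linear(3, 5).state_dict(), buf)
buf.seek(0)

model = nn.Linear(3, 5).to(device)                 # 1) model on the target device first
state_dict = torch.load(buf, map_location=device)  # 2) tensors loaded onto that device
model.load_state_dict(state_dict)                  # 3) copy_ between same-device tensors

print(model.weight.device, state_dict["weight"].device)
```

Because the model already lives on the target device, the copy inside load_state_dict keeps the weights there instead of moving them back to the CPU.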

I understand now; you were right about my source of confusion!

In summary, calling torch.load and .load_state_dict() as a one-liner leaves the model on the CPU, even if you pass a GPU device as map_location to torch.load. If one splits these operations into two lines, one can see that torch.load does put the weights in the dictionary onto the desired device. The second line’s load_state_dict then copies those weights into the model’s existing parameters, which stay on the device the model was created on (here, the CPU).

The following code and output demonstrate this:

Illustration code

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

print(f"PyTorch version: {torch.__version__}\n-------")

gpu = torch.device('cuda:0')  # requires a CUDA-capable GPU
PATH = 'todays_cpu_model.pth'

class Net(nn.Module):
    def __init__(self, indim):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(indim, 16)
        self.fc2 = nn.Linear(16, 1)
        
    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = self.fc2(t)
        return t

# Create a checkpoint first if one is not already on disk
if not os.path.exists(PATH):
    torch.save(Net(4).state_dict(), PATH)

print("Calling `torch.load` and `model.load_state_dict()` in one line")
model_1 = Net(4)
model_1.load_state_dict(torch.load(PATH, map_location=gpu))
w_1 = next(model_1.parameters()).detach()
print(f"Device of first layer params: {w_1.device}")


print("------\nCalling `torch.load` and `model.load_state_dict()` in TWO DIFFERENT lines")
model_2 = Net(4)
loaded = torch.load(PATH, map_location=gpu)

# access weight in ordered dictionary
w_2a = loaded["fc1.weight"].detach()
print(f"Device of first layer params after `torch.load`: {w_2a.device}")

# now load the state dict and access weight in model object
model_2.load_state_dict(loaded)
w_2b = model_2.fc1.weight.detach()
print(f"Device of first layer params after `.load_state_dict()`: {w_2b.device}")

# Sanity check
# Confirm that the two models are, in fact, equivalent
tests = list()
for kk in model_1.state_dict():
    tests.append(
        torch.equal(model_1.state_dict()[kk].cpu(),model_2.state_dict()[kk].cpu())
    )
    
print(f"----\nTest that the models are identical: {all(tests)}")

The output:

-------
Calling `torch.load` and `model.load_state_dict()` in one line
Device of first layer params: cpu
------
Calling `torch.load` and `model.load_state_dict()` in TWO DIFFERENT lines
Device of first layer params after `torch.load`: cuda:0
Device of first layer params after `.load_state_dict()`: cpu
----
Test that the models are identical: True