How to load constant tensors in the model onto GPU?

I am facing a weird issue. I am doing the following:

  1. Freeze a torchvision model on a CPU host, which internalizes the parameters/weight tensors as constants.
  2. Save this frozen module.
  3. Try to load and run the model on a GPU host.

I see that the weight tensor constants are not getting loaded onto the GPU despite my invoking .to(device) on the loaded model.

How do we ensure that the tensor constants, of the saved model, get loaded onto GPU?
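The three steps above can be sketched roughly as follows. This is a minimal reproduction under stated assumptions, not the original script: a single torch.nn.Conv2d stands in for the torchvision model, and the file name frozen_conv.pt is hypothetical. On a CPU-only host the call simply succeeds; on a GPU host it may raise the device-mismatch error shown below, which is why the call is wrapped in try/except.

```python
import os
import tempfile

import torch

# Hypothetical stand-in for the torchvision model: a single conv layer.
model = torch.nn.Conv2d(3, 8, kernel_size=3).eval()

# Step 1: script and freeze on the CPU; the weights become graph constants.
frozen = torch.jit.freeze(torch.jit.script(model))

# Step 2: save the frozen module (file name is an assumption).
path = os.path.join(tempfile.mkdtemp(), "frozen_conv.pt")
torch.jit.save(frozen, path)

# Step 3: load it and try to move it with .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loaded = torch.jit.load(path).to(device)

x = torch.randn(1, 3, 16, 16, device=device)
try:
    out = loaded(x)
    status = "ok"
except RuntimeError as err:
    # On a GPU host this is where the input/weight type mismatch may surface.
    status = f"failed: {err}"
print(status)
```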

Error StackTrace:

Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torchvision/models/segmentation/fcn/", line 9, in forward
    input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
    features = torch.dict()
    x0 = torch.conv2d(x, CONSTANTS.c0, None, [2, 2], [3, 3], [1, 1], 1)
         ~~~~~~~~~~~~ <--- HERE
    x1 = torch.batch_norm(x0, CONSTANTS.c1, CONSTANTS.c2, CONSTANTS.c3, CONSTANTS.c4, False, 0.10000000000000001, 1.0000000000000001e-05, True)
    x2 = torch.relu_(x1)
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

torch.jit.load(...., map_location="cuda") will work on master.

import torch

scripted_module = torch.jit.script(torch.nn.Linear(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# Freezing inlines weight and bias as constants, so no parameters remain.
assert len(list(frozen_module.named_parameters())) == 0

and then compare the graphs (first with the constants on the CPU, then reloaded with map_location="cuda"):

graph(%self : __torch__.torch.nn.modules.linear.___torch_mangle_0.Linear,
      %input.1 : Tensor):
  %8 : Tensor = prim::Constant[value=-0.6421  0.1414 -0.4491 -0.5532 -0.0829  0.0914 [ CPUFloatType{2,3} ]]() # :0:0
  %6 : Tensor = prim::Constant[value=0.01 * -9.0403 -27.0162  3.8648 [ CPUFloatType{3} ]]() # :0:0
  %4 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/dist-packages/torch/nn/
  %9 : int = prim::Constant[value=1]() # :0:0
  %3 : int = aten::dim(%input.1) # /usr/local/lib/python3.9/dist-packages/torch/nn/
  %5 : bool = aten::eq(%3, %4) # /usr/local/lib/python3.9/dist-packages/torch/nn/
  %ret : Tensor = prim::If(%5) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      %ret0.1 : Tensor = aten::addmm(%6, %input.1, %8, %9, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      -> (%ret0.1)
      %output.1 : Tensor = aten::matmul(%input.1, %8) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      %output0.1 : Tensor = aten::add_(%output.1, %6, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      -> (%output0.1)
  return (%ret)

In [17]: torch.jit.load('/tmp/', map_location="cuda").graph
graph(%self : __torch__.torch.nn.modules.linear.___torch_mangle_0.Linear,
      %input.1 : Tensor):
  %8 : Tensor = prim::Constant[value=-0.6421  0.1414 -0.4491 -0.5532 -0.0829  0.0914 [ CUDAFloatType{2,3} ]]() # :0:0
  %6 : Tensor = prim::Constant[value=0.01 * -9.0403 -27.0162  3.8648 [ CUDAFloatType{3} ]]() # :0:0
  %4 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/dist-packages/torch/nn/
  %9 : int = prim::Constant[value=1]() # :0:0
  %3 : int = aten::dim(%input.1) # /usr/local/lib/python3.9/dist-packages/torch/nn/
  %5 : bool = aten::eq(%3, %4) # /usr/local/lib/python3.9/dist-packages/torch/nn/
  %ret : Tensor = prim::If(%5) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      %ret0.1 : Tensor = aten::addmm(%6, %input.1, %8, %9, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      -> (%ret0.1)
      %output.1 : Tensor = aten::matmul(%input.1, %8) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      %output0.1 : Tensor = aten::add_(%output.1, %6, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/
      -> (%output0.1)
  return (%ret)

Note that the constants move from CPU to CUDA in the second example.
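Rather than reading the printed graph, the device of the tensor constants can also be checked programmatically. The sketch below is one possible way to do it, assuming the graph inspection helpers findAllNodes and toIValue behave as in recent PyTorch releases; it saves to an in-memory buffer instead of a file, and falls back to "cpu" when no GPU is available.

```python
import io

import torch

# Freeze a small module and serialize it into memory.
frozen = torch.jit.freeze(torch.jit.script(torch.nn.Linear(2, 3).eval()))
buffer = io.BytesIO()
torch.jit.save(frozen, buffer)

# Reload with map_location so the constants are placed on the target device.
device = "cuda" if torch.cuda.is_available() else "cpu"
buffer.seek(0)
loaded = torch.jit.load(buffer, map_location=device)

# Collect the device type of every tensor-valued constant in the graph.
constant_devices = set()
for node in loaded.graph.findAllNodes("prim::Constant"):
    value = node.output().toIValue()
    if isinstance(value, torch.Tensor):
        constant_devices.add(value.device.type)
print(constant_devices)
```

If map_location did its job, the set should contain only the requested device type.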

Best regards


Thanks for this quick suggestion. I will try to dig deeper to understand why .to(device) doesn’t do the same.
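One plausible explanation, offered as an assumption rather than a confirmed diagnosis: Module.to moves parameters and buffers, and a frozen module no longer has any, because freezing turned them into graph constants. The sketch below only counts parameters and buffers before and after freezing; it does not inspect what .to does internally.

```python
import torch

scripted = torch.jit.script(torch.nn.Linear(2, 3).eval())
frozen = torch.jit.freeze(scripted)

# Parameters visible to .to() before and after freezing.
n_scripted = len(list(scripted.named_parameters()))
n_frozen = len(list(frozen.named_parameters()))
n_frozen_buffers = len(list(frozen.named_buffers()))
print(n_scripted, n_frozen, n_frozen_buffers)
```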

I can’t look today, but if you remind me in a week I might take a look.

This is quite an old bug report, but I’m seeing the same issue. In my case I can’t rely on map_location: I have two models and not enough GPU memory for both, so I have to keep moving them back and forth between CPU and GPU. I’m hoping somebody has looked into it since this was reported.
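A possible workaround for the two-model case, sketched under assumptions and not a fix for the underlying behavior: keep each frozen model as serialized bytes in host memory and reload it with map_location whenever it is needed on the GPU, dropping the previous copy first. The helper names freeze_to_bytes and load_on are hypothetical, and the two Linear layers stand in for the real models.

```python
import io

import torch


def freeze_to_bytes(module):
    """Script, freeze, and serialize a module into an in-memory byte string."""
    frozen = torch.jit.freeze(torch.jit.script(module.eval()))
    buffer = io.BytesIO()
    torch.jit.save(frozen, buffer)
    return buffer.getvalue()


def load_on(serialized, device):
    """Deserialize a frozen module with its constants placed on `device`."""
    return torch.jit.load(io.BytesIO(serialized), map_location=device)


device = "cuda" if torch.cuda.is_available() else "cpu"
blob_a = freeze_to_bytes(torch.nn.Linear(2, 3))
blob_b = freeze_to_bytes(torch.nn.Linear(3, 4))

x = torch.randn(5, 2, device=device)
model_a = load_on(blob_a, device)
hidden = model_a(x)
del model_a  # drop the first model's device copy before loading the next one
model_b = load_on(blob_b, device)
out = model_b(hidden)
print(out.shape)
```

The cost is a deserialization on every swap, so this only makes sense when swaps are infrequent relative to inference calls.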