How to load constant tensors in the model onto GPU?

I am facing a weird issue. I am doing the following:

  1. Freeze a torchvision model on CPU host, which internalizes parameters/weight tensors as constants.
  2. Save this frozen module.
  3. Try to load and run the model on GPU host.

I see that the weight tensor constants are not moved to the GPU, even though I call model.to("cuda:0") on the loaded model.

How do we ensure that the saved model's tensor constants get loaded onto the GPU?

Error stack trace:

Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/torchvision/models/segmentation/fcn/___torch_mangle_17301.py", line 9, in forward
    input_shape = torch.slice(torch.size(x), -2, 9223372036854775807, 1)
    features = torch.dict()
    x0 = torch.conv2d(x, CONSTANTS.c0, None, [2, 2], [3, 3], [1, 1], 1)
         ~~~~~~~~~~~~ <--- HERE
    x1 = torch.batch_norm(x0, CONSTANTS.c1, CONSTANTS.c2, CONSTANTS.c3, CONSTANTS.c4, False, 0.10000000000000001, 1.0000000000000001e-05, True)
    x2 = torch.relu_(x1)
 
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

Passing map_location when loading, i.e. torch.jit.load(..., map_location="cuda"), will work on master. For example:

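import torch

# Script and freeze a small module; freezing inlines its parameters as constants.
scripted_module = torch.jit.script(torch.nn.Linear(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# No parameters remain on the frozen module.
assert len(list(frozen_module.named_parameters())) == 0
print(frozen_module.code)
frozen_module.save('/tmp/x.pt')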

and then

In [16]: torch.jit.load('/tmp/x.pt').graph
Out[16]: 
graph(%self : __torch__.torch.nn.modules.linear.___torch_mangle_0.Linear,
      %input.1 : Tensor):
  %8 : Tensor = prim::Constant[value=-0.6421  0.1414 -0.4491 -0.5532 -0.0829  0.0914 [ CPUFloatType{2,3} ]]() # :0:0
  %6 : Tensor = prim::Constant[value=0.01 * -9.0403 -27.0162  3.8648 [ CPUFloatType{3} ]]() # :0:0
  %4 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:22
  %9 : int = prim::Constant[value=1]() # :0:0
  %3 : int = aten::dim(%input.1) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7
  %5 : bool = aten::eq(%3, %4) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7
  %ret : Tensor = prim::If(%5) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:4
    block0():
      %ret0.1 : Tensor = aten::addmm(%6, %input.1, %8, %9, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1665:14
      -> (%ret0.1)
    block1():
      %output.1 : Tensor = aten::matmul(%input.1, %8) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1667:17
      %output0.1 : Tensor = aten::add_(%output.1, %6, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1669:12
      -> (%output0.1)
  return (%ret)


In [17]: torch.jit.load('/tmp/x.pt', map_location="cuda").graph
Out[17]: 
graph(%self : __torch__.torch.nn.modules.linear.___torch_mangle_0.Linear,
      %input.1 : Tensor):
  %8 : Tensor = prim::Constant[value=-0.6421  0.1414 -0.4491 -0.5532 -0.0829  0.0914 [ CUDAFloatType{2,3} ]]() # :0:0
  %6 : Tensor = prim::Constant[value=0.01 * -9.0403 -27.0162  3.8648 [ CUDAFloatType{3} ]]() # :0:0
  %4 : int = prim::Constant[value=2]() # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:22
  %9 : int = prim::Constant[value=1]() # :0:0
  %3 : int = aten::dim(%input.1) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7
  %5 : bool = aten::eq(%3, %4) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:7
  %ret : Tensor = prim::If(%5) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1663:4
    block0():
      %ret0.1 : Tensor = aten::addmm(%6, %input.1, %8, %9, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1665:14
      -> (%ret0.1)
    block1():
      %output.1 : Tensor = aten::matmul(%input.1, %8) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1667:17
      %output0.1 : Tensor = aten::add_(%output.1, %6, %9) # /usr/local/lib/python3.9/dist-packages/torch/nn/functional.py:1669:12
      -> (%output0.1)
  return (%ret)

Note that the constants move from CPU to CUDA in the second example.
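
For reference, the whole round trip can be put into one script. This is only a minimal sketch: the temporary file path, the input shape, and the fallback to CPU when no GPU is available are my own assumptions for illustration, and it relies on a PyTorch build in which map_location also remaps frozen constants (as on master above).

```python
import os
import tempfile

import torch

# Assumption for illustration: fall back to CPU so the sketch also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Freeze on CPU: the Linear weight and bias are inlined into the graph as constants.
scripted = torch.jit.script(torch.nn.Linear(2, 3).eval())
frozen = torch.jit.freeze(scripted)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "frozen_linear.pt")  # hypothetical file name
    frozen.save(path)

    # map_location remaps the stored tensors to the target device at load time.
    loaded = torch.jit.load(path, map_location=device)

    # The input must live on the same device as the constants.
    x = torch.randn(4, 2, device=device)
    with torch.no_grad():
        y = loaded(x)

print(y.shape, y.device)
```

If the constants had stayed on the CPU while x was on the GPU, the forward call would fail with the same kind of device mismatch error as in the original stack trace.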

Best regards

Thomas

Thanks for this quick suggestion. I will try to dig deeper to understand why .to(device) doesn’t do the same.
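
A likely explanation, not confirmed in this thread: Module.to() moves a module's parameters and buffers, and after freezing there are none left to move, because the weights now live in the graph as constants. The sketch below only checks that observation on a small Linear module; the variable names are mine.

```python
import torch

linear = torch.nn.Linear(2, 3).eval()
scripted = torch.jit.script(linear)

# Before freezing, the weights are ordinary parameters that .to() can move.
params_before = [name for name, _ in scripted.named_parameters()]

frozen = torch.jit.freeze(scripted)

# After freezing, nothing is registered as a parameter or buffer any more.
params_after = [name for name, _ in frozen.named_parameters()]
buffers_after = [name for name, _ in frozen.named_buffers()]

# The frozen module still computes the same result, so the weights must be
# stored elsewhere (inlined in the graph as constants).
x = torch.randn(4, 2)
with torch.no_grad():
    same_output = torch.allclose(linear(x), frozen(x), atol=1e-6)

print(params_before, params_after, buffers_after, same_output)
```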

I can’t look today, but if you remind me in a week I might take a look.