Device pinned to tracing device in torch.jit.trace for TorchScript

I am trying to load a TorchScript module on a different device from the one it was traced on, e.g. cuda:3. The problem is that the device is a constant in the graph, pinned to the tracing device (cuda:0). The map_location argument of torch.jit.load doesn't help either. Is there a way either a) to modify the traced graph and unpin the device constant, or b) to parameterize the device while tracing so it is not pinned to a specific device?

graph(%self.1 : __torch__.mmedit.models.inpaintors.two_stage.TwoStageInpaintor,
      %masked_img.1 : Tensor,
      %mask.1 : Tensor):
  %11 : NoneType = prim::Constant()
  %10 : bool = prim::Constant[value=0]() # /mmediting/mmedit/models/inpaintors/
  %43 : Device = prim::Constant[value="cuda:0"]()
  %6 : int = prim::Constant[value=6]() # /mmediting/mmedit/models/inpaintors/
  %7 : int = prim::Constant[value=0]() # /mmediting/mmedit/models/inpaintors/
  %17 : int = prim::Constant[value=1]() # /mmediting/mmedit/models/inpaintors/
  %26 : float = prim::Constant[value=1.]() # /opt/conda/lib/python3.7/site-packages/torch/
  %4 : __torch__.mmedit.models.backbones.encoder_decoders.two_stage_encoder_decoder.DeepFillEncoderDecoder = prim::GetAttr[name="generator"](%self.1)
  %tmp_ones.1 : Tensor = aten::ones_like(%mask.1, %6, %7, %43, %10, %11) # /mmediting/mmedit/models/inpaintors/
  %16 : Tensor[] = prim::ListConstruct(%masked_img.1, %tmp_ones.1, %mask.1)
  %x.1 : Tensor = aten::cat(%16, %17) # /mmediting/mmedit/models/inpaintors/
  %21 : Tensor = prim::CallMethod[name="forward"](%4, %x.1) # :0:0
  %23 : Tensor = aten::mul(%21, %mask.1) # /mmediting/mmedit/models/inpaintors/
  %28 : Tensor = aten::rsub(%mask.1, %26, %17) # /opt/conda/lib/python3.7/site-packages/torch/
  %29 : Tensor = aten::mul(%masked_img.1, %28) # /mmediting/mmedit/models/inpaintors/
  %33 : Tensor = aten::add(%23, %29, %17) # /mmediting/mmedit/models/inpaintors/
  return (%33)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking argument for argument tensors in method wrapper__cat)

Maybe you can try the script below and check:

import torch
import io

# Load ScriptModule from an io.BytesIO object
with open('', 'rb') as f:
    buffer = io.BytesIO(f.read())

# Load all tensors to the original device
torch.jit.load(buffer)
buffer.seek(0)

# Load all tensors onto the CPU, using a device
torch.jit.load(buffer, map_location=torch.device('cpu'))

Hopefully this should work.

I know very little about torch tracing.
My guess is that you are using ops which bake in the device. For example, slicing like tensor[1:7] keeps track of the device. You should replace such operators with PyTorch ones.

@azhanmohammed Thanks for the reply!

The code snippet does the same as torch.jit.load('', map_location=torch.device('cpu')),
so unfortunately the above error still persists.

Maybe the best way is to modify the graph node, i.e. the constant %43 : Device = prim::Constant[value="cuda:0"](), but I haven't figured out a neat way to do it without breaking the graph.

@JuanFMontesinos could you explain which tensor to slice?

The slicing is just an example of an op carried out in Python that ends up hardcoded in the jit graph.
In your case it seems you are defining a constant.

  %10 : bool = prim::Constant[value=0]() # /mmediting/mmedit/models/inpaintors/
  %43 : Device = prim::Constant[value="cuda:0"]()

I don’t really know how it was created, but maybe you can just turn it into a PyTorch buffer and it will work.
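For reference, a minimal sketch of that idea (the module name is made up): a registered buffer is stored on the module rather than baked into the graph as a constant, so it is relocated by .to(device) and by map_location when loading:

```python
import torch

class WithBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A registered buffer is serialized with the module and moved
        # by .to(device) / map_location, unlike a Device constant that
        # gets frozen into the traced graph
        self.register_buffer("tmp_ones", torch.ones(1))

    def forward(self, mask):
        return mask * self.tmp_ones

traced = torch.jit.trace(WithBuffer(), torch.zeros(2, 2))
traced = traced.to("cpu")  # .to("cuda:3") would relocate the buffer too
```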