I am trying to load a TorchScript model on a different device from the one it was traced on, e.g. cuda:3. The problem is that the device ends up as a constant in the graph, pinned to the tracing device (cuda:0). The map_location argument of torch.jit.load doesn't help either. Is there a way either a) to modify the traced graph and unpin the device constant, or b) to parameterize the device while tracing, so it is not pinned to a specific device?
graph(%self.1 : __torch__.mmedit.models.inpaintors.two_stage.TwoStageInpaintor,
      %masked_img.1 : Tensor,
      %mask.1 : Tensor):
  %11 : NoneType = prim::Constant()
  %10 : bool = prim::Constant[value=0]() # /mmediting/mmedit/models/inpaintors/two_stage.py:76:0
  %43 : Device = prim::Constant[value="cuda:0"]()
  %6 : int = prim::Constant[value=6]() # /mmediting/mmedit/models/inpaintors/two_stage.py:76:0
  %7 : int = prim::Constant[value=0]() # /mmediting/mmedit/models/inpaintors/two_stage.py:76:0
  %17 : int = prim::Constant[value=1]() # /mmediting/mmedit/models/inpaintors/two_stage.py:77:0
  %26 : float = prim::Constant[value=1.]() # /opt/conda/lib/python3.7/site-packages/torch/_tensor.py:544:0
  %4 : __torch__.mmedit.models.backbones.encoder_decoders.two_stage_encoder_decoder.DeepFillEncoderDecoder = prim::GetAttr[name="generator"](%self.1)
  %tmp_ones.1 : Tensor = aten::ones_like(%mask.1, %6, %7, %43, %10, %11) # /mmediting/mmedit/models/inpaintors/two_stage.py:76:0
  %16 : Tensor[] = prim::ListConstruct(%masked_img.1, %tmp_ones.1, %mask.1)
  %x.1 : Tensor = aten::cat(%16, %17) # /mmediting/mmedit/models/inpaintors/two_stage.py:77:0
  %21 : Tensor = prim::CallMethod[name="forward"](%4, %x.1) # :0:0
  %23 : Tensor = aten::mul(%21, %mask.1) # /mmediting/mmedit/models/inpaintors/two_stage.py:81:0
  %28 : Tensor = aten::rsub(%mask.1, %26, %17) # /opt/conda/lib/python3.7/site-packages/torch/_tensor.py:544:0
  %29 : Tensor = aten::mul(%masked_img.1, %28) # /mmediting/mmedit/models/inpaintors/two_stage.py:81:0
  %33 : Tensor = aten::add(%23, %29, %17) # /mmediting/mmedit/models/inpaintors/two_stage.py:81:0
  return (%33)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0! (when checking arugment for argument tensors in method wrapper__cat)