Torch.jit.trace and devices

Hey, I have built a transfer learning model for object detection. I would like to load it from C++ code, so I found a way to do it:

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)

traced_script_module.save("traced_resnet_model.pt")

So I tried it, but I ran into an issue:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-bdcffa324305> in <module>
----> 1 traced_script_module = torch.jit.trace(model, torch.rand( 1, 3, 1080, 1920))

~/anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    733 
    734     if isinstance(func, torch.nn.Module):
--> 735         return trace_module(
    736             func,
    737             {"forward": example_inputs},

~/anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    950             example_inputs = make_tuple(example_inputs)
    951 
--> 952             module._c._create_method_from_trace(
    953                 method_name,
    954                 func,

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1037                 recording_scopes = False
   1038         try:
-> 1039             result = self.forward(*input, **kwargs)
   1040         finally:
   1041             if recording_scopes:

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     91                                      .format(degen_bb, target_idx))
     92 
---> 93         features = self.backbone(images.tensors)
     94         if isinstance(features, torch.Tensor):
     95             features = OrderedDict([('0', features)])

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1037                 recording_scopes = False
   1038         try:
-> 1039             result = self.forward(*input, **kwargs)
   1040         finally:
   1041             if recording_scopes:

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/backbone_utils.py in forward(self, x)
     42 
     43     def forward(self, x):
---> 44         x = self.body(x)
     45         x = self.fpn(x)
     46         return x

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1037                 recording_scopes = False
   1038         try:
-> 1039             result = self.forward(*input, **kwargs)
   1040         finally:
   1041             if recording_scopes:

~/anaconda3/lib/python3.8/site-packages/torchvision/models/_utils.py in forward(self, x)
     60         out = OrderedDict()
     61         for name, module in self.items():
---> 62             x = module(x)
     63             if name in self.return_layers:
     64                 out_name = self.return_layers[name]

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1037                 recording_scopes = False
   1038         try:
-> 1039             result = self.forward(*input, **kwargs)
   1040         finally:
   1041             if recording_scopes:

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in forward(self, input)
    441 
    442     def forward(self, input: Tensor) -> Tensor:
--> 443         return self._conv_forward(input, self.weight, self.bias)
    444 
    445 class Conv3d(_ConvNd):

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight, bias)
    437                             weight, bias, self.stride,
    438                             _pair(0), self.dilation, self.groups)
--> 439         return F.conv2d(input, weight, bias, self.stride,
    440                         self.padding, self.dilation, self.groups)
    441 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument weight in method wrapper_thnn_conv2d_forward)

But when I define my model, I initially call model.to(device). And the model works; I can evaluate it.

So I tried adding .to(device) again, but I still get a very similar error:

traced_script_module = torch.jit.trace(model.to(device), torch.rand( 1, 3, 1080, 1920).to(device))

The error is:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
/home/nathaneberrebi/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/anchor_utils.py:123: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-19-d7fcc86bd9ec> in <module>
----> 1 traced_script_module = torch.jit.trace(model.to(device), torch.rand( 1, 3, 1080, 1920).to(device))

~/anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    733 
    734     if isinstance(func, torch.nn.Module):
--> 735         return trace_module(
    736             func,
    737             {"forward": example_inputs},

~/anaconda3/lib/python3.8/site-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    950             example_inputs = make_tuple(example_inputs)
    951 
--> 952             module._c._create_method_from_trace(
    953                 method_name,
    954                 func,

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1037                 recording_scopes = False
   1038         try:
-> 1039             result = self.forward(*input, **kwargs)
   1040         finally:
   1041             if recording_scopes:

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     94         if isinstance(features, torch.Tensor):
     95             features = OrderedDict([('0', features)])
---> 96         proposals, proposal_losses = self.rpn(images, features, targets)
     97         detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
     98         detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
   1037                 recording_scopes = False
   1038         try:
-> 1039             result = self.forward(*input, **kwargs)
   1040         finally:
   1041             if recording_scopes:

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/rpn.py in forward(self, images, features, targets)
    354         proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
    355         proposals = proposals.view(num_images, -1, 4)
--> 356         boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
    357 
    358         losses = {}

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/rpn.py in filter_proposals(self, proposals, objectness, image_shapes, num_anchors_per_level)
    243 
    244         # select top_n boxes independently per level before applying nms
--> 245         top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
    246 
    247         image_range = torch.arange(num_images, device=device)

~/anaconda3/lib/python3.8/site-packages/torchvision/models/detection/rpn.py in _get_top_n_idx(self, objectness, num_anchors_per_level)
    223                 pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
    224             _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
--> 225             r.append(top_n_idx + offset)
    226             offset += num_anchors
    227         return torch.cat(r, dim=1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Thanks for your help!

I cannot reproduce the issue using the posted code and just get a NamedTensor warning:

UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  ../c10/core/TensorImpl.h:1257.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)

The traced model is successfully stored in the current working directory.