Not possible to torch.jit.trace a Faster R-CNN on the CUDA device

Hi, I tried to run this code on Google Colab and got an error. I don’t understand why this error occurs, since I just downloaded a Faster R-CNN model and called the torch.jit.trace function.

import torch
import torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.to(device)
model.eval()

input_t = torch.rand(1,3,224,224).to(device)
module = torch.jit.trace(model,input_t)

module.save("module.pth")

I got this error:

  for i in range(dim)
/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/anchor_utils.py:123: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-3-cfa0d532d4aa> in <module>()
      8 
      9 input_t = torch.rand(1,3,224,224).to(device)
---> 10 module = torch.jit.trace(model,input_t)
     11 
     12 module.save("module.pth")

9 frames

/usr/local/lib/python3.7/dist-packages/torchvision/models/detection/rpn.py in _get_top_n_idx(self, objectness, num_anchors_per_level)
    223                 pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
    224             _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
--> 225             r.append(top_n_idx + offset)
    226             offset += num_anchors
    227         return torch.cat(r, dim=1)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Does anyone have a solution?

Hi @ptrblck, sorry for the mention, but I still haven’t found a solution. Do you have any idea?

I’m not sure what exactly is causing the error (I can reproduce it), so I would recommend creating an issue in the torchvision repository. As a workaround, you could use torch.jit.script on the model instead of torch.jit.trace, which should work.