Expected device cuda:0 but got device cpu while converting to onnx

Hi, I am trying to convert a resnet18_fpn Faster R-CNN model to ONNX using the code shown below:

import torch

# get_model_instance is defined elsewhere in my codebase
model = get_model_instance("resnet18_fpn", 15).eval()
checkpoint_file = "/home/sharkspotter/faster-rcnn-pytorch/checkpoints/drone_resnet18_fpn_optimizer_sgd_classes_15_epoch_30.pth"
checkpoint = torch.load(checkpoint_file)
model.load_state_dict(checkpoint['state_dict'])
model = model.cuda()

for name, param in model.named_parameters():
    print(name, param.size())
print('**********************************************************************************')

x = torch.randn(1, 3, 224, 224, device='cuda', requires_grad=False)
torch.onnx.export(model, x, 'faster_rcnn.onnx', export_params=True,
                  do_constant_folding=False, verbose=True, opset_version=11)

It throws the following error:

Traceback (most recent call last):
  File "test.py", line 300, in <module>
    torch.onnx.export(model,x,'faster_rcnn.onnx',export_params= True ,do_constant_folding=False,verbose=True,opset_version = 11)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/__init__.py", line 148, in export
    strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/utils.py", line 66, in export
    dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/utils.py", line 416, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/utils.py", line 279, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/onnx/utils.py", line 236, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(model, args, _force_outplace=True, _return_inputs_states=True)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 277, in _get_trace_graph
    outs = ONNXTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 532, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 360, in forward
    self._force_outplace,
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/jit/__init__.py", line 347, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 530, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 516, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torchvision/models/detection/generalized_rcnn.py", line 70, in forward
    proposals, proposal_losses = self.rpn(images, features, targets)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 530, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 516, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torchvision/models/detection/rpn.py", line 472, in forward
    boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torchvision/models/detection/rpn.py", line 379, in filter_proposals
    top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)
  File "/home/sharkspotter/anaconda3/envs/pytorch/lib/python3.6/site-packages/torchvision/models/detection/rpn.py", line 359, in _get_top_n_idx
    r.append(top_n_idx + offset)
RuntimeError: expected device cuda:0 but got device cpu

Could you please help me with this?

You should specify a device, such as cuda:0.

Could you check if all parameters were successfully loaded onto the device via:

for param in model.parameters():
    print(param.device)
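
It might also be worth printing the buffer devices (e.g. BatchNorm running stats), which model.parameters() does not include; a minimal sketch:

for name, buf in model.named_buffers():
    print(name, buf.device)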

I tried this as well, but it is not working.

Yes, I have checked that the parameters are loaded successfully onto the device:

for param in model.parameters():
    print(param.device)

cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
cuda:0
........

try:

x = Variable(torch.randn(1, 3, 224, 224)).cuda()
torch.onnx.export(model, x, "fast_rcnn.onnx", verbose=True, ...)
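
Note that Variable has been deprecated since PyTorch 0.4; without it, the same idea would be something like:

x = torch.randn(1, 3, 224, 224, device='cuda')
torch.onnx.export(model, x, "fast_rcnn.onnx", verbose=True, opset_version=11)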

Very late response, but I fixed this with a bit of a hack:

from torchvision.models.detection import rpn

# Keep a reference to the original helper before monkey-patching it.
old_onnx_get_num_anchors_and_pre_nms_top_n = rpn._onnx_get_num_anchors_and_pre_nms_top_n  # pylint: disable = protected-access

def _onnx_get_num_anchors_and_pre_nms_top_n_fixed(*args, **kwargs):
    num_anchors, pre_nms_top_n = old_onnx_get_num_anchors_and_pre_nms_top_n(*args, **kwargs)
    # During ONNX export, num_anchors comes back as a CPU tensor; collapse it to a plain int.
    if not isinstance(num_anchors, int):
        num_anchors = num_anchors.item()
    return num_anchors, pre_nms_top_n

rpn._onnx_get_num_anchors_and_pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n_fixed

Put this at the top of your file. It seems num_anchors is incorrectly returned as a CPU tensor (it should just be an int) during onnx.export.

Edit:
On further consideration, I think the reason it is converted to a tensor is to make sure the trace works; calling .item() could end up baking the value in as a constant when converting to ONNX (it seemed to work fine for me with my posted solution, but it is a possibility). Perhaps the safer move is to replace num_anchors = num_anchors.item() with num_anchors = num_anchors.to(DEVICE), where DEVICE is defined outside the scope of the function, or simply num_anchors = num_anchors.cuda() if you know you will run it on CUDA.
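
For reference, a sketch of that safer variant; DEVICE here is an assumed module-level variable that should match the device your model runs on:

import torch
from torchvision.models.detection import rpn

DEVICE = torch.device('cuda:0')  # assumed; match the device your model runs on

old_fn = rpn._onnx_get_num_anchors_and_pre_nms_top_n  # pylint: disable = protected-access

def _onnx_get_num_anchors_and_pre_nms_top_n_fixed(*args, **kwargs):
    num_anchors, pre_nms_top_n = old_fn(*args, **kwargs)
    # Move the tensor to the model's device instead of collapsing it to a Python int,
    # so the traced graph keeps it as a tensor rather than a baked-in constant.
    if not isinstance(num_anchors, int):
        num_anchors = num_anchors.to(DEVICE)
    return num_anchors, pre_nms_top_n

rpn._onnx_get_num_anchors_and_pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n_fixed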