torch.onnx.export: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

When I move the model and the input tensor to a CUDA device and then export to ONNX, I get the error above: "RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!". My model is complicated. How can I solve this?

Is your model generally working via torch.jit.script or torch.jit.trace or are you seeing the same issue?
Could you post an executable code snippet to reproduce the issue, please?

I’m encountering a similar issue. I’m attempting to convert a PyTorch model to ONNX with fp16 precision. I’m using the following call:

    torch.onnx.export(
        model,
        input_tensor,
        onnx_file_path,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
    )

Both model and input_tensor are fp16 and on gpu (model.cuda(), model.half(), etc.). But I still get the following error:

File "/home/{USER}/.local/lib/python3.8/site-packages/torch/onnx/utils.py", line 628, in _model_to_graph
    params_dict = torch._C._jit_pass_onnx_eval_peephole(graph, params_dict)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

When I inspect graph at line 628 of site-packages/torch/onnx/utils.py, I see that some onnx::Constant tensors (among a few others) did not make it onto the GPU. See the full graph below (notice values %39, %40, %44, %45, %50, %51, %55, %56, %61, %62, %66, %67, %72, %73, %77, %78, %79, %80, %81).

graph(%input : Half(1, 3, 32, 32, strides=[3, 1, 96, 3], requires_grad=0, device=cuda:0),
      %features.0.weight : Half(32, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.0.bias : Half(32, strides=[1], requires_grad=1, device=cuda:0),
      %features.1.running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0),
      %features.1.running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0),
      %features.1.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %features.4.weight : Half(32, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.4.bias : Half(32, strides=[1], requires_grad=1, device=cuda:0),
      %features.5.running_mean : Half(32, strides=[1], requires_grad=0, device=cuda:0),
      %features.5.running_var : Half(32, strides=[1], requires_grad=0, device=cuda:0),
      %features.5.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %features.9.weight : Half(64, 32, 3, 3, strides=[288, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.9.bias : Half(64, strides=[1], requires_grad=1, device=cuda:0),
      %features.10.running_mean : Half(64, strides=[1], requires_grad=0, device=cuda:0),
      %features.10.running_var : Half(64, strides=[1], requires_grad=0, device=cuda:0),
      %features.10.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %features.13.weight : Half(64, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.13.bias : Half(64, strides=[1], requires_grad=1, device=cuda:0),
      %features.14.running_mean : Half(64, strides=[1], requires_grad=0, device=cuda:0),
      %features.14.running_var : Half(64, strides=[1], requires_grad=0, device=cuda:0),
      %features.14.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %features.18.weight : Half(128, 64, 3, 3, strides=[576, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.18.bias : Half(128, strides=[1], requires_grad=1, device=cuda:0),
      %features.19.running_mean : Half(128, strides=[1], requires_grad=0, device=cuda:0),
      %features.19.running_var : Half(128, strides=[1], requires_grad=0, device=cuda:0),
      %features.19.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %features.22.weight : Half(128, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.22.bias : Half(128, strides=[1], requires_grad=1, device=cuda:0),
      %features.23.running_mean : Half(128, strides=[1], requires_grad=0, device=cuda:0),
      %features.23.running_var : Half(128, strides=[1], requires_grad=0, device=cuda:0),
      %features.23.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %features.27.weight : Half(256, 128, 3, 3, strides=[1152, 9, 3, 1], requires_grad=1, device=cuda:0),
      %features.27.bias : Half(256, strides=[1], requires_grad=1, device=cuda:0),
      %features.28.running_mean : Half(256, strides=[1], requires_grad=0, device=cuda:0),
      %features.28.running_var : Half(256, strides=[1], requires_grad=0, device=cuda:0),
      %features.28.num_batches_tracked : Long(requires_grad=0, device=cuda:0),
      %classifier.0.weight : Half(10, 256, strides=[256, 1], requires_grad=1, device=cuda:0),
      %classifier.0.bias : Half(10, strides=[1], requires_grad=1, device=cuda:0)):
  %38 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%input, %features.0.weight, %features.0.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %39 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %40 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %41 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%38, %39, %40, %features.1.running_mean, %features.1.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %42 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%41) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %43 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%42, %features.4.weight, %features.4.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %44 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %45 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %46 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%43, %44, %45, %features.5.running_mean, %features.5.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %47 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%46) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %48 : Half(1, 32, 16, 16, strides=[8192, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%47) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:719:0
  %49 : Half(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%48, %features.9.weight, %features.9.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %50 : Half(64, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %51 : Half(64, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %52 : Half(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%49, %50, %51, %features.10.running_mean, %features.10.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %53 : Half(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%52) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %54 : Half(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%53, %features.13.weight, %features.13.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %55 : Half(64, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %56 : Half(64, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %57 : Half(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%54, %55, %56, %features.14.running_mean, %features.14.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %58 : Half(1, 64, 16, 16, strides=[16384, 256, 16, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%57) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %59 : Half(1, 64, 8, 8, strides=[4096, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%58) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:719:0
  %60 : Half(1, 128, 8, 8, strides=[8192, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%59, %features.18.weight, %features.18.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %61 : Half(128, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %62 : Half(128, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %63 : Half(1, 128, 8, 8, strides=[8192, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%60, %61, %62, %features.19.running_mean, %features.19.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %64 : Half(1, 128, 8, 8, strides=[8192, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%63) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %65 : Half(1, 128, 8, 8, strides=[8192, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%64, %features.22.weight, %features.22.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %66 : Half(128, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %67 : Half(128, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %68 : Half(1, 128, 8, 8, strides=[8192, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%65, %66, %67, %features.23.running_mean, %features.23.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %69 : Half(1, 128, 8, 8, strides=[8192, 64, 8, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%68) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %70 : Half(1, 128, 4, 4, strides=[2048, 16, 4, 1], requires_grad=1, device=cuda:0) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%69) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:719:0
  %71 : Half(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[0, 0, 0, 0], strides=[1, 1]](%70, %features.27.weight, %features.27.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py:490:0
  %72 : Half(256, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %73 : Half(256, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %74 : Half(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::BatchNormalization[epsilon=1.0000000000000001e-05, momentum=0.90000000000000002](%71, %72, %73, %features.28.running_mean, %features.28.running_var) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:2282:0
  %75 : Half(1, 256, 2, 2, strides=[1024, 4, 2, 1], requires_grad=1, device=cuda:0) = onnx::Relu(%74) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1169:0
  %76 : Half(1, 256, 1, 1, strides=[256, 1, 1, 1], requires_grad=1, device=cuda:0) = onnx::MaxPool[kernel_shape=[2, 2], pads=[0, 0, 0, 0], strides=[2, 2]](%75) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:719:0
  %77 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %78 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={-1}]()
  %79 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[axes=[0]](%77)
  %80 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[axes=[0]](%78)
  %81 : Long(2, strides=[1], device=cpu) = onnx::Concat[axis=0](%79, %80)
  %82 : Half(1, 256, strides=[256, 1], requires_grad=1, device=cuda:0) = onnx::Reshape(%76, %81)
  %83 : Half(1, 10, strides=[10, 1], requires_grad=1, device=cuda:0) = onnx::Gemm[alpha=1., beta=1., transB=1](%82, %classifier.0.weight, %classifier.0.bias) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1848:0
  %output : Half(1, 10, strides=[10, 1], requires_grad=1, device=cuda:0) = onnx::Softmax[axis=1](%83) # /home/jordan/.local/lib/python3.8/site-packages/torch/nn/functional.py:1680:0
  return (%output)
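
In a dump this long, the CPU-resident values are easy to miss by eye. The sketch below is one way to pick them out mechanically, assuming the graph has been captured as plain text in the format shown above. The helper name `values_on_device` and the regular expression are illustrative only and are not part of PyTorch.

```python
import re

# Matches node definitions such as
#   %39 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
# and captures the value name, its device, and the op that produces it.
# Graph inputs (which have no "= op" part) are deliberately not matched.
NODE_RE = re.compile(r"^\s*(%[\w.]+)\s*:\s*.*?device=([\w:]+)\)\s*=\s*([\w:]+)")


def values_on_device(graph_text, device="cpu"):
    """Return (value_name, op) pairs for nodes whose output is on `device`."""
    hits = []
    for line in graph_text.splitlines():
        match = NODE_RE.match(line)
        if match and match.group(2) == device:
            hits.append((match.group(1), match.group(3)))
    return hits


# A few lines copied (and shortened) from the dump above.
sample = """\
  %38 : Half(1, 32, 32, 32, strides=[32768, 1024, 32, 1], requires_grad=1, device=cuda:0) = onnx::Conv[dilations=[1, 1]](%input, %features.0.weight, %features.0.bias)
  %39 : Half(32, strides=[1], device=cpu) = onnx::Constant[value=<Tensor>]()
  %77 : Long(requires_grad=0, device=cpu) = onnx::Constant[value={1}]()
  %79 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[axes=[0]](%77)
"""

for name, op in values_on_device(sample):
    print(name, op)
```

Run over the full dump, this should list the same values called out above, together with the op that produced each one.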

So my question is: how can I access these tensors in my PyTorch model and force them onto the GPU? I tried messing with the model’s _apply function as described here, but still couldn’t get it to work.
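
Before trying to move anything, it may help to confirm which tensors the module itself owns and where they live. The sketch below groups parameters and buffers by device; `tensors_by_device` and the small stand-in model are hypothetical and not taken from the original code. The flagged values in the graph are onnx::Constant nodes rather than graph inputs, so they may not correspond to anything this audit lists. In that case they would have been created during export, not stored on the model.

```python
import itertools
from collections import defaultdict

import torch.nn as nn


def tensors_by_device(module):
    """Group the names of all parameters and buffers by the device they are on."""
    groups = defaultdict(list)
    named = itertools.chain(module.named_parameters(), module.named_buffers())
    for name, tensor in named:
        groups[str(tensor.device)].append(name)
    return dict(groups)


# Hypothetical stand-in for the real model. affine=False is only a guess,
# made because the graph above lists no BatchNorm weight or bias inputs.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8, affine=False))

for device, names in tensors_by_device(model).items():
    print(device, names)
```

If every name ends up under a single device after model.cuda(), the module state is consistent and the mismatch most likely comes from tensors the exporter creates on its own.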


I think the issue you are seeing might be caused by torch.jit.trace, which the ONNX exporter uses by default. Could you try to trace the model without ONNX and see if it still fails?
If so, try torch.jit.script instead and see whether that works. If it does, use it in the ONNX export as well.
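
A minimal sketch of that check, outside the ONNX exporter, could look like the following. The small model is a hypothetical stand-in for the real one, it stays in fp32 so it also runs on a CPU-only machine, and the `check` helper is illustrative only.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model: conv + BatchNorm + ReLU + linear head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
).eval()

# Use the GPU when one is present, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
input_tensor = torch.randn(1, 3, 32, 32, device=device)


def check(name, build):
    """Build a TorchScript module, run it once, and report any RuntimeError."""
    try:
        module = build()
        with torch.no_grad():
            out = module(input_tensor)
        print(f"{name}: ok, output on {out.device}")
        return True
    except RuntimeError as err:
        print(f"{name}: failed with {err}")
        return False


traced_ok = check("torch.jit.trace", lambda: torch.jit.trace(model, input_tensor))
scripted_ok = check("torch.jit.script", lambda: torch.jit.script(model))
```

If both paths succeed here but the export still fails, the device mismatch is more likely introduced during the ONNX conversion itself than by tracing or scripting.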

I tried to use torch.jit.trace and it was able to successfully trace the model.

    traced_module = torch.jit.trace(model, input_tensor)  # success

But I still get the same error when I try to use the traced module in the onnx export.

    torch.onnx.export(
        traced_module,
        input_tensor,
        onnx_file_path,
        input_names=["input"],
        output_names=["output"],
        export_params=True,
    ) # fail, same as before

Thanks for the test!
I don’t see any information on whether jit.script would work with the ONNX export, but in any case you might want to create an issue so that the code owners can debug the error.

Thanks for the replies. I tried jit.script in ONNX as well, but still got the same error.
I just opened an issue here.
