Can I use the same input dtype for torch.jit and Torch-TensorRT?

I am testing the conversion of a simple GCN model from PyTorch Geometric to TensorRT.

The source code used is:

Converting the model to TorchScript with torch.jit.trace succeeded:

jit_model = torch.jit.trace(model, (data.x, data.edge_index))
output_jit = jit_model(data.x, data.edge_index)

However, compiling the model with torch_tensorrt failed.

Code and error:

with torch_tensorrt.logging.debug():
    trt_model = torch_tensorrt.compile(model, inputs=[data.x, data.edge_index])


Traceback (most recent call last):
  File "torch_geo2.py", line 83, in <module>
    trt_model = torch_tensorrt.compile(model, inputs=[data.x, data.edge_index])
  File "/opt/anaconda3/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 125, in compile
    return torch_tensorrt.ts.compile(
  File "/opt/anaconda3/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 136, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: Unknown type NoneType encountered in graph lowering. This type is not supported in ONNX export.

While investigating the error, I found that TensorRT does not accept int64 inputs, so I tried casting edge_index to int32.

However, PyTorch does not accept int32 tensors as indices, so the model then fails inside PyTorch Geometric:

Error:

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "torch_geo2.py", line 27, in forward
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv1(x, edge_index).relu()
            ~~~~~~~~~~ <--- HERE
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
  File "/tmp/root_pyg/tmp7_glyyjg.py", line 226, in forward__0
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                                              ~~~~~~~~ <--- HERE
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, self.flow, x.dtype)
  File "/opt/anaconda3/lib/python3.8/site-packages/torch_geometric/nn/conv/gcn_conv.py", line 61, in gcn_norm

        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                                          ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
  File "/opt/anaconda3/lib/python3.8/site-packages/torch_geometric/utils/loop.py", line 298, in add_remaining_self_loops

        inv_mask = ~mask
        loop_attr[edge_index[0][inv_mask]] = edge_attr[inv_mask]
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

        edge_attr = torch.cat([edge_attr[mask], loop_attr], dim=0)
RuntimeError: tensors used as indices must be long, byte or bool tensors

In short, int32 inputs cause an error in PyTorch, while int64 inputs cause an error in Torch-TensorRT.
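A possible workaround to test, sketched under assumptions: keep int32 at the model's input boundary (the dtype TensorRT accepts) and cast to int64 inside forward before any indexing, so PyTorch's index ops always receive long tensors. `Int32IndexWrapper` and `ToyGather` are hypothetical names; `ToyGather` is a toy message-passing layer standing in for GCNConv, not PyTorch Geometric's implementation.

```python
import torch
import torch.nn as nn


class ToyGather(nn.Module):
    """Hypothetical stand-in for a GNN layer (not PyG's GCNConv): sums each
    source node's features into its target node."""

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        src, dst = edge_index[0], edge_index[1]
        out = torch.zeros_like(x)
        # Both the gather x[src] and index_add_ use edge_index as indices.
        out.index_add_(0, dst, x[src])
        return out


class Int32IndexWrapper(nn.Module):
    """Hypothetical wrapper: the public edge_index input is int32, while the
    inner model always receives int64 indices."""

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.inner = inner

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        # Cast once at the entry point; everything downstream sees int64.
        return self.inner(x, edge_index.to(torch.int64))


if __name__ == "__main__":
    x = torch.arange(8, dtype=torch.float32).reshape(4, 2)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]], dtype=torch.int32)
    model = Int32IndexWrapper(ToyGather()).eval()
    traced = torch.jit.trace(model, (x, edge_index))  # graph input is int32
    print(traced(x, edge_index))
```

This only moves the int64 tensor inside the graph. Whether Torch-TensorRT then converts the cast and indexing ops or leaves them to run in PyTorch depends on its version and is not verified here.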

Is there any way to solve this?

Thanks in advance for any help.

CC @narendasan as the TorchTRT expert

What version of Torch-TensorRT are you using? We have added a compatibility mode for INT64 that should address the dtype issues you are seeing. The ONNX lowering error usually appears when the model is not in eval mode.
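A minimal sketch of the reply's two suggestions, assuming a TorchScript-based workflow: call model.eval() before tracing, then compile the traced module with explicit input specs and the INT64 compatibility mode enabled. `TinyNet`, `trace_for_inference` and `compile_with_trt` are hypothetical names. The assumption that the compatibility mode is exposed as the `truncate_long_and_double` flag should be checked against the installed Torch-TensorRT version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical model that gates dropout on self.training, the same
    pattern as the GCN forward in the traceback (not the actual model)."""

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(2, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.dropout(x, p=0.5, training=self.training)
        return self.lin(x).relu()


def trace_for_inference(model: nn.Module, *example_inputs: torch.Tensor):
    # eval() makes dropout a no-op, so the traced graph records no
    # training-only random ops.
    model.eval()
    with torch.no_grad():
        return torch.jit.trace(model, example_inputs)


def compile_with_trt(traced, input_specs):
    """Needs torch_tensorrt and a CUDA GPU; not executed in this sketch.

    input_specs is a list of (shape, dtype) pairs, e.g.
    [((num_nodes, num_features), torch.float32), ((2, num_edges), torch.int32)].
    Assumption: the TorchScript frontend exposes the INT64 compatibility
    mode as truncate_long_and_double.
    """
    import torch_tensorrt

    return torch_tensorrt.compile(
        traced,
        ir="ts",
        inputs=[torch_tensorrt.Input(shape=s, dtype=d) for s, d in input_specs],
        truncate_long_and_double=True,
    )
```

Tracing the already-evaluated model also means the graph handed to Torch-TensorRT matches what runs at inference time, rather than a training-mode graph with dropout.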