Indexing or masking in ONNX

Description

Hi,
I am trying to do a simple task, indexing or masking, and export it to ONNX while using the "ONNX_ATEN_FALLBACK" operator export type in PyTorch, but I can't find a way to do it.
I tried all of the following solutions (a rough sketch of each is shown right after this list):

  • Normal boolean masking, "x[mask]", but it produces an ATen operation during the conversion from PyTorch to ONNX
  • Normal indexing, "x[arange()][idx]", but it gives me the same result as the previous one
  • The gather function, but it is not supported in TRT yet
  • The masked_select function, but it doesn't work because it is converted to "NonZero, Expand", which are not supported in TRT

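For clarity, this is roughly what I mean by each of the variants above. It is only a minimal sketch with illustrative shapes (matching the repro code further below); all four lines select the same rows of x:

import torch

x = torch.randn(300, 44)
idx = torch.tensor([1, 2])
mask = torch.zeros(300, dtype=torch.bool)
mask[idx] = True

rows_mask   = x[mask]                                                             # 1) boolean masking
rows_index  = x[torch.arange(300)][idx]                                           # 2) arange + index
rows_gather = torch.gather(x, 0, idx.unsqueeze(1).expand(-1, x.shape[1]))         # 3) gather
rows_select = torch.masked_select(x, mask.unsqueeze(1)).reshape(-1, x.shape[1])   # 4) masked_select
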
So are there any other solutions that don't require writing external C++ functions?

Note: I have to use "ONNX_ATEN_FALLBACK" because I am using an external NMS plugin (batchedNMSPlugin).

The indexing is exported correctly without ONNX_ATEN_FALLBACK, but I have to use that export type since I am relying on the external plugin.

Here is a simple piece of code to reproduce the error.

import torch
import torch.nn as nn


class TestModel(nn.Module):

    def __init__(self):
        super(TestModel, self).__init__()

    def forward(self, args):
        dummy_input, idx = args
        # The problematic line: advanced indexing, exported as aten::index
        # when ONNX_ATEN_FALLBACK is used
        dummy_input = dummy_input[torch.arange(300)][idx]
        return dummy_input


torch_model = TestModel()
dummy_input = torch.randn((300, 44))
idx = torch.tensor([1, 2])
torch_model([dummy_input, idx])  # eager-mode sanity check

torch.onnx.export(torch_model,
                  [dummy_input, idx],
                  'test_model.onnx',
                  verbose=True,
                  opset_version=11,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

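For the comparison further below, the same export was run with the default operator export type; a minimal sketch, where only operator_export_type (and the illustrative output filename) differs:

torch.onnx.export(torch_model,
                  [dummy_input, idx],
                  'test_model_default.onnx',   # illustrative filename
                  verbose=True,
                  opset_version=11,
                  operator_export_type=torch.onnx.OperatorExportTypes.ONNX)
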
The output using ONNX_ATEN_FALLBACK:

graph(%0 : Float(300:44, 44:1, requires_grad=0, device=cpu),
      %1 : Long(2:1, requires_grad=0, device=cpu)):
  %2 : Long(300:1, requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>]()
  %3 : Tensor[] = onnx::SequenceConstruct(%2)
  %4 : Float(300:44, 44:1, requires_grad=0, device=cpu) = onnx::ATen[operator="index"](%0, %3) # /volumes1/PycharmProjects/od-release/object_detection_framework/devug_onxx.py:27:0
  %5 : Long(2:1, requires_grad=0, device=cpu) = onnx::Cast[to=7](%1) # /volumes1/PycharmProjects/od-release/object_detection_framework/devug_onxx.py:27:0
  %6 : Tensor?[] = onnx::SequenceConstruct(%5)
  %7 : Float(2:44, 44:1, requires_grad=0, device=cpu) = onnx::ATen[operator="index"](%4, %6) # /volumes1/PycharmProjects/od-release/object_detection_framework/devug_onxx.py:27:0
  return (%7)

The output without using ONNX_ATEN_FALLBACK (PyTorch version 1.7.1):

graph(%0 : Float(300:44, 44:1, requires_grad=0, device=cpu),
      %1 : Long(2:1, requires_grad=0, device=cpu)):
  %2 : Long(300:1, requires_grad=0, device=cpu) = onnx::Constant[value=<Tensor>]()
  %3 : Float(300:44, 44:1, requires_grad=0, device=cpu) = onnx::Gather[axis=0](%0, %2) # /volumes1/PycharmProjects/od-release/object_detection_framework/devug_onxx.py:27:0
  %4 : Long(2:1, requires_grad=0, device=cpu) = onnx::Cast[to=7](%1) # /volumes1/PycharmProjects/od-release/object_detection_framework/devug_onxx.py:27:0
  %5 : Float(2:44, 44:1, requires_grad=0, device=cpu) = onnx::Gather[axis=0](%3, %4) # /volumes1/PycharmProjects/od-release/object_detection_framework/devug_onxx.py:27:0
  return (%5)

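To confirm which operators actually end up in the exported file (ATen index ops with the fallback, Gather without it), I inspect the graph with the onnx Python package; a minimal sketch, assuming the model was saved as 'test_model.onnx' as above:

import onnx
from collections import Counter

model = onnx.load('test_model.onnx')
print(Counter(node.op_type for node in model.graph.node))
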
Environment

TensorRT Version: TensorRT 7.2.2.3
GPU Type: GeForce RTX 2080 Ti/PCIe/SSE2
Nvidia Driver Version: release 460.32.03
CUDA Version: NVIDIA CUDA 11.2.1
CUDNN Version: NVIDIA cuDNN 8.1.0
Operating System + Version: Ubuntu 18.04.3 LTS
Python Version (if applicable): Python 3.6 and 3.8 (respectively)
PyTorch Version (if applicable): PyTorch 1.7.1 and PyTorch 1.8 (respectively)
Baremetal or Container (if container which image + tag): [TensorRT Release 21.03]
OPSET Version: 10, 11, 12, 13 (opset 9 doesn't work because of the upsample layer)