Trying to avoid the onnx::NonZero operation

Hello,
I’m trying to export a model to ONNX and run it with TensorRT.

from polygraphy.backend.trt import EngineFromNetwork, NetworkFromOnnxPath
import torch

class Model(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.x2 = torch.zeros((2048, 1)).cuda()

    def forward(self, x1):
        x2 = self.x2
        idx = x2 < x1
        x1[idx] = x2[idx]
        return x1


if __name__ == '__main__':
    onnx_file = 'test.onnx'

    model = Model()
    x = torch.zeros((2048, 1)).cuda()
    torch.onnx.export(model, x, onnx_file, input_names=['input'], output_names=['output'], opset_version=11)

    build_engine = EngineFromNetwork(NetworkFromOnnxPath(onnx_file))
    engine = build_engine()

The model above exports to ONNX without errors, but TensorRT fails to parse it:

[02/02/2022-17:38:36] [TRT] [E] ModelImporter.cpp:773: While parsing node number 3 [NonZero -> "4"]:
[02/02/2022-17:38:36] [TRT] [E] ModelImporter.cpp:774: --- Begin node ---
[02/02/2022-17:38:36] [TRT] [E] ModelImporter.cpp:775: input: "3"
output: "4"
name: "NonZero_3"
op_type: "NonZero"

[02/02/2022-17:38:36] [TRT] [E] ModelImporter.cpp:776: --- End node ---
[02/02/2022-17:38:36] [TRT] [E] ModelImporter.cpp:779: ERROR: builtin_op_importers.cpp:4870 In function importFallbackPluginImporter:
[8] Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"
[E] In node 3 (importFallbackPluginImporter): UNSUPPORTED_NODE: Assertion failed: creator && "Plugin not found, are the plugin name, version, and namespace correct?"

I’d like to rewrite the code so that the NonZero operation does not appear in the ONNX graph. Is that possible?

This is the onnx graph:

graph(%input : Float(2048, 1, strides=[1, 1], requires_grad=0, device=cuda:0),
      %23 : Long(1, strides=[1], requires_grad=0, device=cpu),
      %24 : Long(1, strides=[1], requires_grad=0, device=cpu)):
  %1 : Float(2048, 1, strides=[1, 1], requires_grad=0, device=cuda:0) = onnx::Constant[value=<Tensor>]() # C:/Users/arosasco/PycharmProjects/pcr/delete2.py:13:0
  %2 : Float(2048, 1, strides=[1, 1], requires_grad=0, device=cuda:0) = onnx::Constant[value=<Tensor>]()
  %3 : Bool(2048, 1, strides=[1, 1], requires_grad=0, device=cuda:0) = onnx::Less(%2, %input) # C:/Users/arosasco/PycharmProjects/pcr/delete2.py:13:0
  %4 : Long(2, *, device=cpu) = onnx::NonZero(%3)
  %5 : Long(*, 2, device=cpu) = onnx::Transpose[perm=[1, 0]](%4)
  %6 : Float(*, strides=[1], requires_grad=0, device=cuda:0) = onnx::GatherND(%1, %5) # C:/Users/arosasco/PycharmProjects/pcr/delete2.py:14:0
  %7 : Long(2, strides=[1], device=cpu) = onnx::Shape(%input)
  %8 : Bool(2048, 1, device=cpu) = onnx::Expand(%3, %7)
  %9 : Long(2, *, device=cpu) = onnx::NonZero(%8)
  %10 : Long(*, 2, device=cpu) = onnx::Transpose[perm=[1, 0]](%9)
  %11 : Long(1, strides=[1], device=cpu) = onnx::Constant[value={-1}]()
  %12 : Float(*, device=cpu) = onnx::Reshape(%6, %11)
  %13 : Long(2, strides=[1], device=cpu) = onnx::Shape(%10)
  %14 : Long(device=cpu) = onnx::Constant[value={0}]()
  %15 : Long(device=cpu) = onnx::Gather[axis=0](%13, %14)
  %18 : Long(1, strides=[1], device=cpu) = onnx::Unsqueeze[axes=[0]](%15)
  %21 : Float(*, device=cpu) = onnx::Slice(%12, %23, %18, %24)
  %output : Float(2048, 1, strides=[1, 1], requires_grad=0, device=cuda:0) = onnx::ScatterND(%input, %10, %21) # C:/Users/arosasco/PycharmProjects/pcr/delete2.py:14:0
  return (%output)

Judging by the source annotations in the graph, the problem comes from lines 13 and 14 of my script:

idx = x2 < x1
x1[idx] = x2[idx]

I’ve tried replacing the first line with idx = torch.zeros_like(x1).to(torch.bool), but the problem persists, so I think the issue is with the second line (the boolean-mask indexing).

I have no idea how to solve this. Can anyone help?

Update: apparently, replacing the masked assignment with x1 = torch.where(x2 < x1, x2, x1) solved the problem.
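
For anyone landing here later, this is a minimal sketch of what the rewrite could look like. It compares the original masked assignment with the torch.where version on random input. The class names MaskedAssignModel and WhereModel are made up for illustration. I also run it on the CPU instead of calling .cuda() so that it works without a GPU, and I clone the input in the masked version so the comparison is not affected by the in-place write. It only checks that the two forward passes agree; it does not export to ONNX or build a TensorRT engine.

```python
import torch


class MaskedAssignModel(torch.nn.Module):
    """Original formulation: boolean-mask indexing (the pattern that produced NonZero)."""

    def __init__(self):
        super().__init__()
        self.x2 = torch.zeros((2048, 1))

    def forward(self, x1):
        x2 = self.x2
        idx = x2 < x1
        x1 = x1.clone()  # assumption: copy so the caller's tensor is not mutated
        x1[idx] = x2[idx]
        return x1


class WhereModel(torch.nn.Module):
    """Rewritten formulation: elementwise select with torch.where, no index tensors."""

    def __init__(self):
        super().__init__()
        self.x2 = torch.zeros((2048, 1))

    def forward(self, x1):
        x2 = self.x2
        # Where x2 < x1 take x2, otherwise keep x1.
        return torch.where(x2 < x1, x2, x1)


torch.manual_seed(0)
x = torch.randn(2048, 1)

expected = MaskedAssignModel()(x)
actual = WhereModel()(x)

# Both formulations should select exactly the same values.
same = torch.equal(expected, actual)
print("outputs match:", same)
```

One difference to keep in mind: the original forward modifies x1 in place, while the torch.where version returns a new tensor and leaves the input untouched. For inputs without NaNs, the result is also just the elementwise minimum of x1 and x2, so torch.minimum(x1, x2) might be another option, although I have not tried exporting that one.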