ONNX Export Placeholder for unsupported PyTorch ops?

I’m trying to use ONNX as an intermediate between PyTorch and TensorRT. When exporting, ONNX complains that grid_sample is not supported. Is there a way to get it to insert a placeholder node, with the same arguments, for a PyTorch op?

I have previously added ONNX support for a correlation function, which more or less exports correctly. ONNX complains about exporting the autograd function, but it works if I comment out the autograd call and use the custom op directly, only when exporting. I tried to be cheeky and call the raw function directly under self.eval() during export instead of the autograd function, but it still goes through autograd.

I have already written IPluginV2IOExt plugins for both grid_sample and correlation, so this export step is currently the bottleneck (until I start debugging them, haha).


Finally figured out that I needed to register the custom op without a domain namespace (I’m not sure that’s strictly the right term), and then it was recognised and replaced correctly. Before that I had tried prefixes like torch::, functional:: and _VariableFunctions::, and wasn’t sure whether I simply didn’t know the right name or whether it wasn’t possible at all.

```python
import torch

# Emit a placeholder node in a custom "torch" domain with the same arguments.
def grid_sample_op(g, input1, input2, mode, padding_mode, align_corners):
    return g.op("torch::grid_sampler", input1, input2, mode, padding_mode, align_corners)

torch.onnx.register_custom_op_symbolic('::grid_sampler', grid_sample_op, 11)
```

Hey @5had3z, were you able to run inference after registering grid_sampler? I was able to export a model that uses grid_sample to ONNX, but I couldn’t run inference with that same model.

I use TensorRT for inference and implemented the required plugins from the PyTorch source code. You can find my TensorRT source code below. If you use ONNX Runtime for inference, I presume you will have to write and register a custom operator for it. Unfortunately I don’t have any experience with ONNX beyond using its graphs as an intermediate representation for TensorRT to parse, so I can’t really help with that.
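For anyone porting `grid_sample` into a plugin the same way, a slow NumPy sketch of the bilinear, zeros-padding path can serve as a reference to compare plugin outputs against. It follows the coordinate convention described in PyTorch's `grid_sample` documentation (grid values in [-1, 1], last dimension ordered (x, y), `align_corners` controlling whether -1 and 1 mean corner-pixel centres or edges). The function names are hypothetical, other interpolation and padding modes are not covered, and it is not the plugin code mentioned above.

```python
import numpy as np


def unnormalize(coord, size, align_corners):
    """Map a normalized coordinate in [-1, 1] to a (fractional) pixel index."""
    if align_corners:
        return (coord + 1) / 2 * (size - 1)  # -1 and 1 hit the centres of the corner pixels
    return ((coord + 1) * size - 1) / 2  # -1 and 1 hit the outer edges of the corner pixels


def grid_sample_bilinear_zeros(inp, grid, align_corners=False):
    """Reference grid_sample: inp is (N, C, H, W), grid is (N, H_out, W_out, 2) holding (x, y).
    Bilinear interpolation with zeros padding only."""
    inp = np.asarray(inp, dtype=np.float64)
    grid = np.asarray(grid, dtype=np.float64)
    n, c, h, w = inp.shape
    _, h_out, w_out, _ = grid.shape
    out = np.zeros((n, c, h_out, w_out))
    for b in range(n):
        for i in range(h_out):
            for j in range(w_out):
                x = unnormalize(grid[b, i, j, 0], w, align_corners)
                y = unnormalize(grid[b, i, j, 1], h, align_corners)
                x0, y0 = int(np.floor(x)), int(np.floor(y))
                # The four neighbouring pixels and their bilinear weights.
                corners = (
                    (y0, x0, (x0 + 1 - x) * (y0 + 1 - y)),
                    (y0, x0 + 1, (x - x0) * (y0 + 1 - y)),
                    (y0 + 1, x0, (x0 + 1 - x) * (y - y0)),
                    (y0 + 1, x0 + 1, (x - x0) * (y - y0)),
                )
                for yy, xx, weight in corners:
                    if 0 <= yy < h and 0 <= xx < w:  # zeros padding: out-of-range pixels add 0
                        out[b, :, i, j] += weight * inp[b, :, yy, xx]
    return out
```

For example, an identity grid built from `np.linspace(-1, 1, W)` and `np.linspace(-1, 1, H)` with `align_corners=True` should return the input unchanged, which makes a quick first check for a plugin.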

As an aside, and as an improvement on my previous solution, the function below ensures correct typing and passes mode, padding_mode and align_corners as typed node attributes (keyword arguments) rather than as extra inputs.

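```python
import torch
from torch.onnx.symbolic_helper import parse_args

@parse_args('v', 'v', 'i', 'i', 'b')  # two tensors, two ints, one bool
def grid_sample_op(g, input1, input2, mode, padding_mode, align_corners):
    return g.op("torch::grid_sampler", input1, input2, interpolation_mode_i=mode,
                padding_mode_i=padding_mode, align_corners_i=align_corners)

torch.onnx.register_custom_op_symbolic('::grid_sampler', grid_sample_op, 11)
```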