RuntimeError: ONNX export failed: Couldn't export Python operator ThreeInterpolate

While exporting the PointNet2 model from torch-points3d (nicolas-chaulet/torch-points3d on GitHub, torch_points3d at master) to ONNX with torch.onnx.export, I got this error:

RuntimeError: ONNX export failed: Couldn't export Python operator ThreeInterpolate

The error comes from the ThreeInterpolate class in torch_points_kernels.torchpoints:

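from typing import Any, Tuple

import torch
from torch.autograd import Function

# tpcpu and tpcuda are torch_points_kernels' compiled CPU and CUDA extension
# modules, imported elsewhere in torch_points_kernels/torchpoints.py (not shown).

class ThreeInterpolate(Function):
    @staticmethod
    def forward(ctx, features, idx, weight):
        # type: (Any, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor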
        B, c, m = features.size()
        n = idx.size(1)
        ctx.three_interpolate_for_backward = (idx, weight, m)
        if features.is_cuda:
            return tpcuda.three_interpolate(features, idx, weight)
        else:
            return tpcpu.knn_interpolate(features, idx, weight)
    @staticmethod
    def backward(ctx, grad_out):
        # type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
        r"""
        Parameters
        ----------
        grad_out : torch.Tensor
            (B, c, n) tensor with gradients of outputs
        Returns
        -------
        grad_features : torch.Tensor
            (B, c, m) tensor with gradients of features
        None
        None
        """
        idx, weight, m = ctx.three_interpolate_for_backward
        if grad_out.is_cuda:
            grad_features = tpcuda.three_interpolate_grad(grad_out.contiguous(), idx, weight, m)
        else:
            grad_features = tpcpu.knn_interpolate_grad(grad_out.contiguous(), idx, weight, m)
        return grad_features, None, None
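For reference, here is a minimal pure-PyTorch sketch of what the forward pass appears to compute. It assumes the usual PointNet++ three-nearest-neighbour semantics: features of shape (B, c, m), idx and weight of shape (B, n, 3), output of shape (B, c, n), with out[b, :, i] = sum over k of weight[b, i, k] * features[b, :, idx[b, i, k]]. The name three_interpolate_torch is hypothetical. Because the sketch only uses torch.gather, reshape/expand and a sum, the ONNX exporter should be able to trace it without a custom symbolic (torch.gather maps to ONNX GatherElements, available from opset 11). Its numerics have not been checked against the compiled kernels.

```python
import torch


def three_interpolate_torch(features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical pure-PyTorch equivalent of ThreeInterpolate.forward (assumed semantics).

    features: (B, c, m), idx: (B, n, 3) neighbour indices, weight: (B, n, 3) -> (B, c, n)
    """
    B, c, _ = features.shape
    n = idx.shape[1]
    # Flatten the 3 neighbours per target point and repeat the indices for every channel.
    flat_idx = idx.long().reshape(B, 1, n * 3).expand(B, c, n * 3)
    # gathered[b, ch, i, k] = features[b, ch, idx[b, i, k]]
    gathered = torch.gather(features, 2, flat_idx).reshape(B, c, n, 3)
    # Weighted sum over the 3 neighbours; weight is broadcast across channels.
    return (gathered * weight.unsqueeze(1)).sum(dim=-1)


features = torch.tensor([[[1.0, 2.0, 3.0]]])  # (B=1, c=1, m=3)
idx = torch.tensor([[[0, 1, 2]]])             # (B=1, n=1, 3)
weight = torch.tensor([[[0.5, 0.25, 0.25]]])  # (B=1, n=1, 3)
print(three_interpolate_torch(features, idx, weight))  # tensor([[[1.7500]]])
```

The trade-off is that the exported graph materializes a (B, c, n, 3) intermediate instead of calling a fused kernel.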

I read the torch.onnx documentation (PyTorch 1.10.0) on how to support torch.autograd.Function-based classes,

but after adding a symbolic method (below; its arguments are probably wrong, but the export never got far enough for that to be the error):

    @staticmethod
    def symbolic(g, features, idx, weight):
        return g.op("ThreeInterpolate", features, idx, weight, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))

Now the error is that it cannot find a ThreeInterpolate function. Does that mean such functions have to be part of a symbolic_opset in the first place? Is there a suggested workaround for exporting this function? At this point I only need the export to succeed; the model is already trained.
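A likely cause, based on how the TorchScript-based exporter in PyTorch 1.x handles op names: g.op("ThreeInterpolate", ...) without a namespace emits a node in the standard ONNX domain, where no ThreeInterpolate operator exists. A custom operator does not have to live in symbolic_opset; prefixing the name with a domain emits the node in that custom domain instead. In the sketch below the domain torch_points and the class ThreeInterpolateExportable are hypothetical choices, and forward is a slow pure-PyTorch stand-in for the compiled kernel (assuming the usual PointNet++ semantics) so the block runs on its own.

```python
import torch
from torch.autograd import Function


class ThreeInterpolateExportable(Function):
    """Hypothetical ThreeInterpolate variant whose ONNX symbolic uses a custom domain."""

    @staticmethod
    def forward(ctx, features, idx, weight):
        # Stand-in for the compiled kernel (assumed semantics):
        # out[b, :, i] = sum_k weight[b, i, k] * features[b, :, idx[b, i, k]]
        B = features.shape[0]
        batch = torch.arange(B, device=features.device).view(B, 1, 1)
        neighbors = features.transpose(1, 2)[batch, idx.long()]  # (B, n, 3, c)
        out = (neighbors * weight.unsqueeze(-1)).sum(dim=2)      # (B, n, c)
        return out.transpose(1, 2).contiguous()                  # (B, c, n)

    @staticmethod
    def symbolic(g, features, idx, weight):
        # "torch_points::" places the node in a custom domain rather than the
        # default ONNX domain, so the exporter does not look for a standard op.
        return g.op("torch_points::ThreeInterpolate", features, idx, weight)
```

When exporting, pass custom_opsets={"torch_points": 1} to torch.onnx.export so the custom domain gets a version. The resulting .onnx file only runs in a runtime that has a ThreeInterpolate implementation registered for that domain (for example an ONNX Runtime custom op), so this route only helps if you can provide one.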

Full error:

Traceback (most recent call last):
  File "forward_scripts/checkpoint_export2.py", line 217, in <module>
    main()
  File "/venv/lib/python3.8/site-packages/hydra/main.py", line 32, in decorated_main
    _run_hydra(
  File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 346, in _run_hydra
    run_and_report(
  File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 201, in run_and_report
    raise ex
  File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 198, in run_and_report
    return func()
  File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 347, in <lambda>
    lambda: hydra.run(
  File "/venv/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 107, in run
    return run_job(
  File "/venv/lib/python3.8/site-packages/hydra/core/utils.py", line 129, in run_job
    ret.return_value = task_function(task_cfg)
  File "forward_scripts/checkpoint_export2.py", line 213, in main
    run(model, dataset, device, cfg.output_path)
  File "forward_scripts/checkpoint_export2.py", line 132, in run
    torch.onnx.export(model,
  File "/venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 225, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 85, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 647, in _export
    proto, export_map = graph._export_onnx(
RuntimeError: ONNX export failed: Couldn't export Python operator ThreeInterpolate

Defined at:
/venv/lib/python3.8/site-packages/torch_points_kernels/torchpoints.py(145): three_interpolate
/workdir/forward_scripts/../torch_points3d/core/base_conv/dense.py(140): conv
/workdir/forward_scripts/../torch_points3d/core/base_conv/dense.py(113): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/workdir/forward_scripts/../torch_points3d/models/base_architectures/unet.py(306): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/workdir/forward_scripts/../torch_points3d/models/base_architectures/unet.py(304): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/workdir/forward_scripts/../torch_points3d/models/segmentation/pointnet2.py(208): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/venv/lib/python3.8/site-packages/torch/jit/_trace.py(116): wrapper
/venv/lib/python3.8/site-packages/torch/jit/_trace.py(125): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(727): _call_impl
/venv/lib/python3.8/site-packages/torch/jit/_trace.py(1148): _get_trace_graph
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(342): _trace_and_get_graph_from_model
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(379): _create_jit_graph
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(409): _model_to_graph
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(632): _export
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(85): export
/venv/lib/python3.8/site-packages/torch/onnx/__init__.py(225): export
forward_scripts/checkpoint_export2.py(132): run
forward_scripts/checkpoint_export2.py(213): main
/venv/lib/python3.8/site-packages/hydra/core/utils.py(129): run_job
/venv/lib/python3.8/site-packages/hydra/_internal/hydra.py(107): run
/venv/lib/python3.8/site-packages/hydra/_internal/utils.py(347): <lambda>
/venv/lib/python3.8/site-packages/hydra/_internal/utils.py(198): run_and_report
/venv/lib/python3.8/site-packages/hydra/_internal/utils.py(346): _run_hydra
/venv/lib/python3.8/site-packages/hydra/main.py(32): decorated_main
forward_scripts/checkpoint_export2.py(217): <module>
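If the goal is only a self-contained ONNX file, another option is to keep the compiled kernel for normal use and switch to a pure-PyTorch implementation of the same computation (gather plus weighted sum) only while torch.onnx.export is running, which torch.onnx.is_in_onnx_export() reports. Since the exporter traces Python execution, only the branch actually taken ends up in the graph. The helper export_safe below is hypothetical and not part of torch_points_kernels or torch-points3d.

```python
from typing import Callable

import torch

# Signature shared by the compiled kernel and its pure-PyTorch fallback.
InterpFn = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def export_safe(kernel_fn: InterpFn, fallback_fn: InterpFn) -> InterpFn:
    """Hypothetical helper: use fallback_fn while torch.onnx.export is tracing,
    and kernel_fn (the compiled CPU/CUDA op) the rest of the time."""

    def wrapped(features: torch.Tensor, idx: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            # Only these exportable ops are recorded in the traced graph.
            return fallback_fn(features, idx, weight)
        return kernel_fn(features, idx, weight)

    return wrapped
```

The wrapper would then be bound to whatever name the call site in torch_points3d/core/base_conv/dense.py (line 140 in the traceback) resolves, by reassigning that module attribute to export_safe(original_kernel, pure_torch_version) before calling torch.onnx.export. Which attribute to patch depends on how dense.py imports three_interpolate and has not been verified here.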