While performing a torch onnx export to the pointnet2 model in: torch-points3d/torch_points3d at master · nicolas-chaulet/torch-points3d · GitHub I found myself with the error:
RuntimeError: ONNX export failed: Couldn’t export Python operator ThreeInterpolate
Which is directly related to the ThreeInterpolate class in torch_points_kernels.torchpoints
class ThreeInterpolate(Function):
@staticmethod
def forward(ctx, features, idx, weight):
# type(Any, torch.Tensor, torch.Tensor, torch.Tensor) -> Torch.Tensor
B, c, m = features.size()
n = idx.size(1)
ctx.three_interpolate_for_backward = (idx, weight, m)
if features.is_cuda:
return tpcuda.three_interpolate(features, idx, weight)
else:
return tpcpu.knn_interpolate(features, idx, weight)
@staticmethod
def backward(ctx, grad_out):
# type: (Any, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
r"""
Parameters
----------
grad_out : torch.Tensor
(B, c, n) tensor with gradients of ouputs
Returns
-------
grad_features : torch.Tensor
(B, c, m) tensor with gradients of features
None
None
"""
idx, weight, m = ctx.three_interpolate_for_backward
if grad_out.is_cuda:
grad_features = tpcuda.three_interpolate_grad(grad_out.contiguous(), idx, weight, m)
else:
grad_features = tpcpu.knn_interpolate_grad(grad_out.contiguous(), idx, weight, m)
return grad_features, None, None
I did read this documentation on how to support torch.autograd.Function based classes: torch.onnx — PyTorch 1.10.0 documentation
but when I add the symbolic method (that ended up looking like this… args are wrong but still I didnt even get to the point where those were the error)
@staticmethod
def symbolic(g, features, idx, weight):
return g.op("ThreeInterpolate", features, idx, weight, g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)))
Now the error is that it doesnt find a ThreeInterpolate function… which leads me to understand that these kinds of functions have to be part of the symbolic_opset in the first place? Is there any workaround suggested to be able to export this function? (at this point all that is required is that it can be exported, the model is already trained)
Full error:
Traceback (most recent call last):
File "forward_scripts/checkpoint_export2.py", line 217, in <module>
main()
File "/venv/lib/python3.8/site-packages/hydra/main.py", line 32, in decorated_main
_run_hydra(
File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 346, in _run_hydra
run_and_report(
File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 201, in run_and_report
raise ex
File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 198, in run_and_report
return func()
File "/venv/lib/python3.8/site-packages/hydra/_internal/utils.py", line 347, in <lambda>
lambda: hydra.run(
File "/venv/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 107, in run
return run_job(
File "/venv/lib/python3.8/site-packages/hydra/core/utils.py", line 129, in run_job
ret.return_value = task_function(task_cfg)
File "forward_scripts/checkpoint_export2.py", line 213, in main
run(model, dataset, device, cfg.output_path)
File "forward_scripts/checkpoint_export2.py", line 132, in run
torch.onnx.export(model,
File "/venv/lib/python3.8/site-packages/torch/onnx/__init__.py", line 225, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 85, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/venv/lib/python3.8/site-packages/torch/onnx/utils.py", line 647, in _export
proto, export_map = graph._export_onnx(
RuntimeError: ONNX export failed: Couldn't export Python operator ThreeInterpolate
Defined at:
/venv/lib/python3.8/site-packages/torch_points_kernels/torchpoints.py(145): three_interpolate
/workdir/forward_scripts/../torch_points3d/core/base_conv/dense.py(140): conv
/workdir/forward_scripts/../torch_points3d/core/base_conv/dense.py(113): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/workdir/forward_scripts/../torch_points3d/models/base_architectures/unet.py(306): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/workdir/forward_scripts/../torch_points3d/models/base_architectures/unet.py(304): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/workdir/forward_scripts/../torch_points3d/models/segmentation/pointnet2.py(208): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
/venv/lib/python3.8/site-packages/torch/jit/_trace.py(116): wrapper
/venv/lib/python3.8/site-packages/torch/jit/_trace.py(125): forward
/venv/lib/python3.8/site-packages/torch/nn/modules/module.py(727): _call_impl
/venv/lib/python3.8/site-packages/torch/jit/_trace.py(1148): _get_trace_graph
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(342): _trace_and_get_graph_from_model
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(379): _create_jit_graph
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(409): _model_to_graph
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(632): _export
/venv/lib/python3.8/site-packages/torch/onnx/utils.py(85): export
/venv/lib/python3.8/site-packages/torch/onnx/__init__.py(225): export
forward_scripts/checkpoint_export2.py(132): run
forward_scripts/checkpoint_export2.py(213): main
/venv/lib/python3.8/site-packages/hydra/core/utils.py(129): run_job
/venv/lib/python3.8/site-packages/hydra/_internal/hydra.py(107): run
/venv/lib/python3.8/site-packages/hydra/_internal/utils.py(347): <lambda>
/venv/lib/python3.8/site-packages/hydra/_internal/utils.py(198): run_and_report
/venv/lib/python3.8/site-packages/hydra/_internal/utils.py(346): _run_hydra
/venv/lib/python3.8/site-packages/hydra/main.py(32): decorated_main
forward_scripts/checkpoint_export2.py(217): <module>