How to export split to ONNX with a dynamic split_size?

I need to implement a dynamic tensor split op for work. However, when I export this split op to ONNX with a dynamic split_size, it does not seem to work.

I am new to ONNX. Can anyone help me? Thanks a lot.

To Reproduce

import torch

dummy_input = (torch.tensor([1, 4, 2, 7, 3]), torch.tensor([1, 2, 2]))

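class Split(torch.nn.Module):
    def forward(self, x, l):
        # Convert the split-size tensor l into a Python list of ints
        # before splitting x along the last dimension
        return x.split(l.cpu().numpy().tolist(), dim=-1)
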
model = Split()

with torch.no_grad():
    torch.onnx.export(
        model, dummy_input, 'split.onnx', verbose=False, opset_version=13,
        input_names=['a', 'b'],
        output_names=['c'],
        dynamic_axes={'a': [0], 'b': [0], 'c': [0]}
    )

When I run the exported model with ONNX Runtime (code below), I get this error:
[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:b

import onnxruntime as ort
import torch

model_path = './split.onnx'
sess = ort.InferenceSession(model_path)

a = torch.tensor([4, 2, 3, 4])
b = torch.tensor([1, 3])
sess.run(['c'], {'a':a.numpy(), 'b':b.numpy()})

It seems that tensor b cannot be used as an input, but I do need an input that represents the dynamic split_size.
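For reference, the behaviour the exported model is expected to reproduce can be written out in plain Python. This is a minimal sketch, not PyTorch or ONNX code, and the helper name split_by_sizes is hypothetical. It splits a 1-D sequence into consecutive chunks whose lengths come from a sizes list supplied at run time, which is what x.split(sizes, dim=-1) returns for a 1-D tensor. Like torch.split, it requires the sizes to sum to the input length.

```python
from itertools import accumulate
from typing import List, Sequence


def split_by_sizes(values: Sequence[int], sizes: Sequence[int]) -> List[List[int]]:
    """Split `values` into consecutive chunks with the given lengths.

    Hypothetical reference helper: mirrors x.split(sizes, dim=-1) for a
    1-D tensor, but works on plain Python lists.
    """
    if any(s < 0 for s in sizes):
        raise ValueError("split sizes must be non-negative")
    if sum(sizes) != len(values):
        raise ValueError(
            f"split sizes sum to {sum(sizes)}, but input has length {len(values)}"
        )
    # End offset of each chunk, e.g. sizes [1, 3] -> ends [1, 4]
    ends = list(accumulate(sizes))
    # Each chunk starts where the previous one ended
    starts = [0] + ends[:-1]
    return [list(values[s:e]) for s, e in zip(starts, ends)]


# Dummy input used for export: sizes [1, 2, 2]
print(split_by_sizes([1, 4, 2, 7, 3], [1, 2, 2]))  # [[1], [4, 2], [7, 3]]
# Inputs fed to ONNX Runtime: sizes [1, 3]
print(split_by_sizes([4, 2, 3, 4], [1, 3]))  # [[4], [2, 3, 4]]
```

Note that the number of chunks follows the length of b (three for the dummy input, two for the ONNX Runtime inputs), so a graph with a fixed number of outputs cannot cover both cases, even once b is accepted as an input.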

Environment

  • PyTorch Version (e.g., 1.0): 1.8.1+cu111
  • Python version: 3.8