I need to implement a dynamic tensor split op in my work. But when I export this split op to ONNX with a dynamic split_size, it does not seem to work.
I am new to ONNX. Can anyone help me? Thanks a lot.
To Reproduce
import torch

dummy_input = (torch.tensor([1, 4, 2, 7, 3]), torch.tensor([1, 2, 2]))

class Split(torch.nn.Module):
    def forward(self, x, l):
        return x.split(l.cpu().numpy().tolist(), dim=-1)

model = Split()
with torch.no_grad():
    torch.onnx.export(
        model, dummy_input, 'split.onnx', verbose=False, opset_version=13,
        input_names=['a', 'b'],
        output_names=['c'],
        dynamic_axes={'a': [0], 'b': [0], 'c': [0]}
    )
When I run the exported ONNX model with the code below, I get this error:
[ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:b
import onnxruntime as ort
import torch

model_path = './split.onnx'
sess = ort.InferenceSession(model_path)
a = torch.tensor([4, 2, 3, 4])
b = torch.tensor([1, 3])
sess.run(['c'], {'a': a.numpy(), 'b': b.numpy()})
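The error says the exported graph has no input named b. One way to confirm which inputs the graph actually declares is to compare the feed dict against the names returned by sess.get_inputs() (each returned NodeArg has a .name). The helper below is a hypothetical, pure-Python sketch of that comparison; unknown_feeds is not part of onnxruntime, and the example input list is illustrative only.

```python
from typing import Dict, Iterable, List


def unknown_feeds(graph_input_names: Iterable[str], feeds: Dict[str, object]) -> List[str]:
    """Hypothetical helper: return feed names the graph does not declare as inputs.

    graph_input_names is assumed to come from
    [i.name for i in sess.get_inputs()] on an onnxruntime.InferenceSession.
    """
    declared = set(graph_input_names)
    # Keep the feed dict's order so the report is stable
    return [name for name in feeds if name not in declared]


# Illustrative case: a graph that only declares input 'a'
print(unknown_feeds(['a'], {'a': [4, 2, 3, 4], 'b': [1, 3]}))
```

If b shows up in the returned list, the exporter did not keep it as a graph input, which matches the INVALID_ARGUMENT error above.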
It seems that b cannot be used as an input, but I do need a parameter to represent the dynamic split_size.
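A likely cause: l.cpu().numpy().tolist() turns the split sizes into plain Python ints while the model is traced during export, so the sizes [1, 2, 2] are recorded as constants, b is never used by the graph, and it does not end up as a graph input. Note also that when b has a dynamic length, the number of output chunks is itself dynamic, while ONNX Split has a fixed number of outputs. ONNX provides SplitToSequence (opset 11+) for this case: it takes the split sizes as a tensor input and returns a sequence of tensors. The sketch below only emulates those semantics in pure Python for a 1-D input; it is not the ONNX Runtime implementation, and it does not show how to make the PyTorch exporter emit that op.

```python
from typing import List, Sequence


def split_to_sequence(data: Sequence[int], split: Sequence[int]) -> List[List[int]]:
    """Pure-Python emulation of SplitToSequence on a 1-D input with a 1-D split tensor.

    The number of output chunks equals len(split), so it is only known at run time.
    """
    if any(s < 0 for s in split):
        raise ValueError("split sizes must be non-negative")
    if sum(split) != len(data):
        raise ValueError(f"split sizes sum to {sum(split)}, but the axis has length {len(data)}")
    chunks = []
    start = 0
    for size in split:
        # Take the next `size` elements as one chunk
        chunks.append(list(data[start:start + size]))
        start += size
    return chunks


# Same values as the onnxruntime call above: a = [4, 2, 3, 4], b = [1, 3]
print(split_to_sequence([4, 2, 3, 4], [1, 3]))  # [[4], [2, 3, 4]]
# Same function, different number of chunks (the export-time dummy input)
print(split_to_sequence([1, 4, 2, 7, 3], [1, 2, 2]))  # [[1], [4, 2], [7, 3]]
```

The key point is that both calls go through one function whose split sizes arrive as data, which is what a graph needs if b is to remain a real input.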
Environment
- PyTorch version: 1.8.1+cu111
- Python version: 3.8