Onnx export does not work with model that takes tuple as input

This piece of code doesn’t work as the model takes a tuple as input.

import torch
from typing import Tuple

class TupleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._ln1 = torch.nn.Linear(5, 3)
        self._ln2 = torch.nn.Linear(5, 3)

    def forward(self, x: Tuple[torch.Tensor]):
        print(type(x))
        y1 = self._ln1(x[0])
        y2 = self._ln2(x[1])
        return y1, y2

model = TupleModel()
inputs = (torch.zeros([1, 5]), torch.zeros(1, 5))
model(inputs)
assert type(inputs) == tuple
torch.onnx.export(
    model,
    inputs,
    "/tmp/model.onnx",
    input_names=["input1", "input2"],
    output_names=["output1", "output2"],
)

I got this error:

Traceback (most recent call last):
  File "test.py", line 25, in <module>
    output_names=["output1", "output2"],
  File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
    use_external_data_format=use_external_data_format)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 698, in _export
    dynamic_axes=dynamic_axes)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 456, in _model_to_graph
    use_new_jit_passes)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 417, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 377, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 130, in forward
    self._force_outplace,
  File "/home/wei/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 116, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/wei/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
    result = self._slow_forward(*input, **kwargs)
  File "/home/wei/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
    result = self.forward(*input, **kwargs)

It seems that the onnx export doesn’t take the tuple as a whole instead it takes each element of the tuple as an indivudual argment.