This piece of code doesn’t work as the model takes a tuple as input.
import torch
from typing import Tuple
class TupleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self._ln1 = torch.nn.Linear(5, 3)
self._ln2 = torch.nn.Linear(5, 3)
def forward(self, x: Tuple[torch.Tensor]):
print(type(x))
y1 = self._ln1(x[0])
y2 = self._ln2(x[1])
return y1, y2
model = TupleModel()
inputs = (torch.zeros([1, 5]), torch.zeros(1, 5))
model(inputs)
assert type(inputs) == tuple
torch.onnx.export(
model,
inputs,
"/tmp/model.onnx",
input_names=["input1", "input2"],
output_names=["output1", "output2"],
)
I got this error:
Traceback (most recent call last):
File "test.py", line 25, in <module>
output_names=["output1", "output2"],
File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 698, in _export
dynamic_axes=dynamic_axes)
File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 456, in _model_to_graph
use_new_jit_passes)
File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 417, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/home/wei/.local/lib/python3.6/site-packages/torch/onnx/utils.py", line 377, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/home/wei/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/home/wei/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/wei/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/home/wei/.local/lib/python3.6/site-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/home/wei/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/home/wei/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
It seems that the onnx export doesn’t take the tuple as a whole instead it takes each element of the tuple as an indivudual argment.