Request for minimal example of exporting recurrent models (RNN, LSTM, GRU) to ONNX

Hi, I’m trying to export a basic recurrent model to ONNX, but it seems I’m missing something about the dimension ordering of the inputs. A plain forward pass works fine, but torch.onnx.export fails.

This is my code:

import torch
import torch.onnx

model = torch.nn.GRU(input_size=3,
                     hidden_size=16,
                     num_layers=1)
x = torch.randn(10, 1, 3)
h = torch.zeros(1, 1, 16)

print(model(x, h)) # produces no errors, prints outputs

torch.onnx.export(model, (x, h), 'temp.onnx', export_params=True, verbose=True) # produces the RuntimeError below

Error when running the last line:

Traceback (most recent call last):
  File "pytorch-to-caffe2-via-onnx.py", line 64, in <module>
    run(args)
  File "pytorch-to-caffe2-via-onnx.py", line 44, in run
    torch.onnx.export(model, args=(x, h), f=onnx_proto_output)#, export_params=True, verbose=True)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/onnx/__init__.py", line 25, in export
    return utils.export(*args, **kwargs)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/onnx/utils.py", line 84, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/onnx/utils.py", line 134, in _export
    trace, torch_out = torch.jit.get_trace_graph(model, args)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/jit/__init__.py", line 255, in get_trace_graph
    return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/jit/__init__.py", line 288, in forward
    out = self.inner(*trace_inputs)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/nn/modules/module.py", line 479, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 178, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/home/lysukhin/distr/anaconda3/envs/pytorch-nightly/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 130, in check_forward_args
    self.input_size, input.size(-1)))
RuntimeError: input.size(-1) must be equal to input_size. Expected 3, got 16
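The check that fails compares the last dimension of whatever tensor reached the GRU's input slot with input_size. Getting 16, which is the hidden_size, suggests the hidden state (or something shaped like it) ended up as the input. That is an inference: the code that builds model, x and h inside pytorch-to-caffe2-via-onnx.py is not shown. A minimal sketch, assuming a recent PyTorch release, of the shapes nn.GRU expects by default and of one hypothetical mistake (swapped arguments) that yields the same error:

```python
import torch

seq_len, batch, input_size, hidden_size, num_layers = 10, 1, 3, 16, 1
gru = torch.nn.GRU(input_size=input_size, hidden_size=hidden_size,
                   num_layers=num_layers)

# Default layout (batch_first=False):
#   input: (seq_len, batch, input_size)
#   h0:    (num_layers * num_directions, batch, hidden_size)
x = torch.randn(seq_len, batch, input_size)
h = torch.zeros(num_layers, batch, hidden_size)

output, hn = gru(x, h)
# output: top-layer hidden state at every time step -> (seq_len, batch, hidden_size)
# hn:     final hidden state of every layer          -> (num_layers, batch, hidden_size)
print(tuple(output.shape), tuple(hn.shape))  # (10, 1, 16) (1, 1, 16)

# Hypothetical mistake: arguments swapped, so h (last dim 16) lands in the
# input slot and fails the input_size check with a RuntimeError.
try:
    gru(h, x)
    swapped_error = None
except RuntimeError as err:
    swapped_error = str(err)
print(swapped_error)
```

With batch_first=True only the input and output swap their first two axes; the hidden state keeps the (num_layers * num_directions, batch, hidden_size) layout.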

Your code works on my builds, 0.5.0a0+4028ff6 and 0.4.0.
Which PyTorch version are you using?


Hi, @ptrblck
I’m using 0.4.0 as well.

UPD: the error was on my side, in some other piece of code. Sorry!
This is already a working example, thanks @ptrblck.
🙂