I am trying to export a very simple LSTM model from PyTorch to ONNX. Even though I use a batch size of 1 and pass h0 and c0 as inputs, I still get the following warning:
UserWarning: Exporting a model to ONNX with a batch_size other than 1, with a variable length with LSTM can cause an error when running the ONNX model with a different batch size. Make sure to save the model with a batch size of 1, or define the initial states (h0/c0) as inputs of the model.
Please suggest a correct way to export an LSTM from PyTorch to ONNX. My PyTorch code is below for reference.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
lstm = nn.LSTM(3, 3, num_layers=1)  # input_size=3, hidden_size=3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5
inputs = torch.cat(inputs).view(len(inputs), 1, -1)  # shape (seq_len=5, batch=1, input_size=3)
h0 = torch.randn(1, 1, 3)  # (num_layers, batch, hidden_size)
c0 = torch.randn(1, 1, 3)  # (num_layers, batch, hidden_size)
out, (hn, cn) = lstm(inputs, (h0, c0))
input_names = ["input", "h0", "c0"]
output_names = ["output", "hn", "cn"]
torch.onnx.export(lstm, (inputs, (h0, c0)), "SimpleLSTM_2.onnx", input_names=input_names, output_names=output_names)