I am having trouble getting a model with several LSTMs to export to ONNX properly. The main issue is that I intend to use the model in an online fashion, i.e. feeding in one frame of data at a time.
My LSTM code is similar to the following:
class MyLSTM(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.deployed = False
        self.hidden = torch.zeros(1, 1, dim_out)
        self.cell = torch.zeros(1, 1, dim_out)
        self.lstm = torch.nn.LSTM(input_size=dim_in,
                                  hidden_size=dim_out,
                                  batch_first=True,
                                  bidirectional=False)

    def deploy(self):
        self.deployed = True

    def forward(self, x):
        if self.deployed:
            out, (self.hidden, self.cell) = self.lstm(x, (self.hidden, self.cell))
        else:
            out, _ = self.lstm(x)
        return out
Obviously, before I export the model to ONNX, I call deploy(). The PyTorch model works as expected, and I even tried saving it as a ScriptModule with torch.jit. However, when I load and run the exported model with onnxruntime, its behavior suggests that the hidden/cell state is never updated.
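For reference, here is a minimal sketch of the workaround I have been considering (the class name StatelessLSTM and the shapes are my own, not from any library): since an exported ONNX graph is stateless, the hidden/cell tensors could be made explicit inputs and outputs of forward(), with the caller holding the state between frames.

```python
import torch

class StatelessLSTM(torch.nn.Module):
    """LSTM wrapper whose recurrent state is passed in and out explicitly."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=dim_in,
                                  hidden_size=dim_out,
                                  batch_first=True,
                                  bidirectional=False)

    def forward(self, x, hidden, cell):
        # The caller owns the state and feeds it back on the next frame,
        # so nothing is mutated on the module itself.
        out, (hidden, cell) = self.lstm(x, (hidden, cell))
        return out, hidden, cell

dim_in, dim_out = 4, 8
model = StatelessLSTM(dim_in, dim_out)
x = torch.zeros(1, 1, dim_in)    # one frame: (batch, seq, feature)
h = torch.zeros(1, 1, dim_out)   # (num_layers, batch, hidden_size)
c = torch.zeros(1, 1, dim_out)

# Online usage: carry h and c across calls, one frame at a time.
out, h, c = model(x, h, c)
print(out.shape)  # torch.Size([1, 1, 8])

# Export would then expose the state explicitly, e.g.:
# torch.onnx.export(model, (x, h, c), "lstm.onnx",
#                   input_names=["x", "h_in", "c_in"],
#                   output_names=["out", "h_out", "c_out"])
```

With this shape, the onnxruntime session would be called with h/c from the previous step and return the updated values, but I would prefer to keep the state inside the model if that is at all possible.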
Is it possible to do this? If so, how?