I am having trouble getting a model with several LSTMs to export to ONNX properly. The main issue is that I intend to use the model in an online fashion, i.e. feeding in one frame of data at a time.
My LSTM code is similar to the following:
```python
import torch

class MyLSTM(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.deployed = False
        self.hidden = torch.zeros(1, 1, dim_out)
        self.cell = torch.zeros(1, 1, dim_out)
        self.lstm = torch.nn.LSTM(input_size=dim_in, hidden_size=dim_out,
                                  batch_first=True, bidirectional=False)

    def deploy(self):
        self.deployed = True

    def forward(self, x):
        if self.deployed:
            # Online mode: carry the hidden/cell state across calls.
            out, (self.hidden, self.cell) = self.lstm(x, (self.hidden, self.cell))
        else:
            out, _ = self.lstm(x)
        return out
```
Obviously, before I export the model to ONNX, I call `deploy()`. The PyTorch model works as expected, and it still behaves correctly when I save and reload it via `torch.jit`. However, when I load and run the exported model with onnxruntime, its behavior suggests that the hidden/cell state is never updated.
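For concreteness, here is a minimal, self-contained repro of the PyTorch side (the dimensions `4` and `8` are arbitrary placeholders). In eager mode the hidden/cell state clearly carries over between calls, which is the behavior I expect from the ONNX model as well:

```python
import torch

class MyLSTM(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.deployed = False
        self.hidden = torch.zeros(1, 1, dim_out)
        self.cell = torch.zeros(1, 1, dim_out)
        self.lstm = torch.nn.LSTM(input_size=dim_in, hidden_size=dim_out,
                                  batch_first=True, bidirectional=False)

    def deploy(self):
        self.deployed = True

    def forward(self, x):
        if self.deployed:
            # Online mode: carry the hidden/cell state across calls.
            out, (self.hidden, self.cell) = self.lstm(x, (self.hidden, self.cell))
        else:
            out, _ = self.lstm(x)
        return out

model = MyLSTM(4, 8)
model.deploy()
frame = torch.randn(1, 1, 4)      # one frame: (batch, seq_len=1, features)
before = model.hidden.clone()
with torch.no_grad():
    model(frame)
# In eager PyTorch the state is updated after a single frame,
# so this should print False.
print(torch.equal(before, model.hidden))
```

The exported ONNX graph behaves as if `self.hidden` and `self.cell` were frozen at their initial values instead.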
Is it possible to do this? If so, how?