I use an LSTM to model text with the following code. The shape of `inputs` is `[batch_size, max_seq_len, embedding_size]` and the shape of `input_lens` is `[batch_size]`. `rnn` is simply a bidirectional LSTM defined as follows:
```python
self.rnn = nn.LSTM(self.input_dim, self.hidden_size, self.num_layers,
                   bidirectional=self.bidirectional, dropout=self.dropout,
                   batch_first=True)
```
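For context, here is a minimal, self-contained stand-in for the encoder this line lives in; the class name and the attribute values (`input_dim=100`, `hidden_size=128`, etc.) are placeholders, not my actual configuration:

```python
import torch
import torch.nn as nn

class TextEncoder(nn.Module):
    """Minimal stand-in for the encoder module (names and sizes are assumptions)."""
    def __init__(self, input_dim=100, hidden_size=128, num_layers=2,
                 bidirectional=True, dropout=0.1, cell='lstm'):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.cell = cell
        self.rnn = nn.LSTM(self.input_dim, self.hidden_size, self.num_layers,
                           bidirectional=self.bidirectional, dropout=self.dropout,
                           batch_first=True)

enc = TextEncoder()
x = torch.randn(4, 10, 100)   # [batch_size, max_seq_len, embedding_size]
out, (hn, cn) = enc.rnn(x)
print(out.shape)              # [4, 10, 256]: 2 directions * hidden_size
```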
```python
padded_seq_len = inputs.shape[1]

# Sort the batch by length, descending (keep indices so we can undo the sort)
_, idx_sort = torch.sort(input_lens, dim=0, descending=True)
_, idx_unsort = torch.sort(idx_sort, dim=0)
inputs = inputs.index_select(0, idx_sort)
# pack_padded_sequence expects lengths on the CPU (a 1-D CPU tensor or list of ints)
input_lens = input_lens[idx_sort].cpu()

# Pack so the RNN skips the padded positions
inputs_packed = nn.utils.rnn.pack_padded_sequence(inputs, input_lens,
                                                  batch_first=True)
self.rnn.flatten_parameters()
if str(self.cell).lower() == 'lstm':
    outputs, (hn, cn) = self.rnn(inputs_packed)
elif str(self.cell).lower() == 'gru':
    outputs, hn = self.rnn(inputs_packed)

# Unpack back to [batch_size, padded_seq_len, num_directions * hidden_size]
outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True,
                                           total_length=padded_seq_len)[0]

# Undo the sort so outputs line up with the original batch order
outputs = outputs.index_select(0, idx_unsort)
hn = hn.index_select(1, idx_unsort)  # hn is [num_layers * num_directions, batch, hidden]
```
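As an aside, on PyTorch >= 1.1 the manual sort/unsort bookkeeping can be dropped by passing `enforce_sorted=False` to `pack_padded_sequence`, which sorts and unsorts internally. A sketch with made-up sizes (note, though, that the PyTorch docs say `enforce_sorted=True` is only necessary for ONNX export, so this shortcut may not help the export path):

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(8, 16, num_layers=1, bidirectional=True, batch_first=True)

inputs = torch.randn(3, 5, 8)           # [batch, max_seq_len, embedding]
input_lens = torch.tensor([5, 2, 4])    # unsorted lengths are fine here

# enforce_sorted=False makes pack_padded_sequence sort/unsort internally;
# lengths must still live on the CPU.
packed = nn.utils.rnn.pack_padded_sequence(inputs, input_lens.cpu(),
                                           batch_first=True, enforce_sorted=False)
out_packed, (hn, cn) = rnn(packed)
outputs, _ = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True,
                                              total_length=inputs.shape[1])
print(outputs.shape)                    # [3, 5, 32]
```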
However, when I try to export this model to ONNX with a dynamic batch_size, the ONNX model's outputs do not match those of the original PyTorch checkpoint. Has anyone run into this?
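One sanity check I have considered running before blaming the export itself: the sort/unsort logic should make the encoder invariant to the order of examples in a batch, so permuting the batch and un-permuting the outputs must reproduce the original result in plain PyTorch. A self-contained sketch of that check (the tiny LSTM here is a stand-in, not my model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.LSTM(8, 16, num_layers=1, bidirectional=True, batch_first=True)

def encode(inputs, input_lens):
    # Same sort -> pack -> run -> unpack -> unsort pipeline as in the question
    padded_seq_len = inputs.shape[1]
    _, idx_sort = torch.sort(input_lens, dim=0, descending=True)
    _, idx_unsort = torch.sort(idx_sort, dim=0)
    packed = nn.utils.rnn.pack_padded_sequence(
        inputs.index_select(0, idx_sort), input_lens[idx_sort].cpu(),
        batch_first=True)
    out_packed, (hn, cn) = rnn(packed)
    outputs, _ = nn.utils.rnn.pad_packed_sequence(out_packed, batch_first=True,
                                                  total_length=padded_seq_len)
    return outputs.index_select(0, idx_unsort)

inputs = torch.randn(4, 6, 8)
lens = torch.tensor([6, 3, 5, 2])   # distinct lengths avoid sort-tie ambiguity

perm = torch.randperm(4)
out_a = encode(inputs, lens)
out_b = encode(inputs[perm], lens[perm])[torch.argsort(perm)]
print(torch.allclose(out_a, out_b, atol=1e-6))   # True if sort/unsort is correct
```

If this passes in PyTorch but the ONNX outputs still diverge, the discrepancy is presumably introduced by the exported graph rather than the unsorting logic.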