I have a module whose purpose is to create word features from characters by passing character embeddings through an RNN and taking the last hidden state for each word. I am trying to export the module to ONNX. Here is the forward pass (note that the RNN is initialized with batch_first=False):
def forward(self, inputs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """
    Args:
        inputs (torch.Tensor): Character feature tensor of shape [batch_size x seq_len x max_word_len]
        lengths (torch.Tensor): The lengths of each word [batch_size x seq_len]
    Returns:
        torch.Tensor: The last hidden states of all words.
            [batch_size x seq_len x (layers * directions * hidden_size)]
    """
    inputs_emb = self.embedding(inputs)  # [batch_size x seq_len x max_word_len x embedding_size]
    inputs_emb = self.dropout(inputs_emb)  # [batch_size x seq_len x max_word_len x embedding_size]
    inputs_emb_dims = inputs_emb.size()
    inputs_emb_flat = inputs_emb.view(
        -1, inputs_emb_dims[2], inputs_emb_dims[3]
    )  # [(batch_size * seq_len) x max_word_len x embedding_size]
    lengths_flat = lengths.view(-1)  # [(batch_size * seq_len)]
    seq_len_flat = inputs_emb_flat.size()[0]  # scalar: batch_size * seq_len
    # Move the sequence dimension to the front prior to the RNN
    inputs_emb_flat = inputs_emb_flat.permute(1, 0, 2)  # [max_word_len x (batch_size * seq_len) x embedding_size]
    lengths_flat_sorted, lengths_flat_idx_sorted = lengths_flat.sort(dim=0, descending=True)
    inputs_emb_flat_sorted = inputs_emb_flat[:, lengths_flat_idx_sorted]
    packed = pack_padded_sequence(
        inputs_emb_flat_sorted, lengths_flat_sorted.to("cpu"), batch_first=False, enforce_sorted=True
    )
    if self.rnn.mode == "GRU":
        _, h_n = self.rnn(packed)  # [(layers * directions) x seq_len_flat x hidden_size]
    elif self.rnn.mode == "LSTM":
        _, (h_n, _) = self.rnn(packed)  # [(layers * directions) x seq_len_flat x hidden_size]
    # Move the flattened word dimension to the first dimension and then
    # concatenate the forward/backward hidden states.
    h_n = (
        h_n.permute(1, 0, 2).contiguous().view(seq_len_flat, -1)
    )  # [seq_len_flat x (layers * directions * hidden_size)]
    # Unsort the hidden states.
    _, lengths_flat_idx_unsorted_idx = lengths_flat_idx_sorted.sort(dim=0)
    h_n_unsorted = h_n[lengths_flat_idx_unsorted_idx]  # [seq_len_flat x (layers * directions * hidden_size)]
    # Transform back to original input size
    output = h_n_unsorted.view(
        inputs_emb_dims[0], inputs_emb_dims[1], -1
    )  # [batch_size x seq_len x (layers * directions * hidden_size)]
    return output
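The sort/unsort step in the forward pass can be checked in isolation. Below is a minimal sketch in plain Python (no torch, so the list operations only stand in for `Tensor.sort` and tensor indexing). The helper names `argsort` and `gather` and the sample data are made up for illustration. It shows that sorting the sort indices yields the permutation that restores the original order.

```python
def argsort(values, descending=False):
    """Return the indices that would sort `values`."""
    return sorted(range(len(values)), key=lambda i: values[i], reverse=descending)


def gather(values, indices):
    """Pick values[i] for each i in indices, like tensor[indices]."""
    return [values[i] for i in indices]


lengths = [3, 5, 1, 4]            # hypothetical word lengths
words = ["w0", "w1", "w2", "w3"]  # stand-ins for per-word hidden states

# Like lengths_flat.sort(dim=0, descending=True): longest word first.
idx_sorted = argsort(lengths, descending=True)
words_sorted = gather(words, idx_sorted)

# Like lengths_flat_idx_sorted.sort(dim=0): the inverse permutation.
idx_unsorted = argsort(idx_sorted)
words_restored = gather(words_sorted, idx_unsorted)

print(idx_sorted)       # [1, 3, 0, 2]
print(words_restored)   # ['w0', 'w1', 'w2', 'w3']
```

If `words_restored` matches `words`, the unsort logic itself is sound, which suggests looking at how shapes are captured during export instead.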
I am exporting the model via tracing as follows:
torch.onnx.export(
    model,
    input_example,
    filepath,
    verbose=True,
    opset_version=12,
    input_names=["char_seqs", "char_seqs_lens"],
    output_names=["embeddings"],
    dynamic_axes={
        "char_seqs": {1: "seq_len"},
        "char_seqs_lens": {1: "seq_len"},
        "embeddings": {1: "seq_len"},
    },
)
When I export the model, I do not get any errors or warnings. However, when I try to pass a sequence whose length differs from that of the input example, I get the following error:
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Expand node. Name:'Expand_39' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:483 void onnxruntime::BroadcastIterator::Append(int64_t, int64_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 500 by 1951
I do not understand why this is not working, since I believe I am taking all the necessary steps to support a dynamic sequence length.
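For reference, the message describes a broadcasting shape check of the kind sketched below. This is a plain-Python illustration of NumPy-style broadcasting, which the ONNX Expand operator follows. `broadcast_shape` is a hypothetical helper, not an onnxruntime function, and the shapes are illustrative: only the sizes 500 and 1951 come from the error. The sketch assumes one of those sizes was fixed in the graph at trace time and the other comes from the new input, which I have not verified.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast of two shapes, or raise ValueError if incompatible."""
    # Right-align the shapes by left-padding the shorter one with 1s.
    rank = max(len(shape_a), len(shape_b))
    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
    result = []
    for dim_a, dim_b in zip(a, b):
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            # Two different sizes, neither of them 1: nothing can be stretched.
            raise ValueError(f"cannot broadcast {dim_a} by {dim_b}")
    return tuple(result)


print(broadcast_shape((1, 8), (500, 1)))  # (500, 8): size-1 axes are stretched

try:
    broadcast_shape((500, 8), (1951, 8))  # 500 vs 1951: neither is 1
except ValueError as err:
    print("error:", err)                  # error: cannot broadcast 500 by 1951
```

Under that assumption, the failure would mean some size that should follow the input was recorded as a constant in the exported graph.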
Edit 1
I believe the reason this is not working is that two lengths are dynamic at once: the length of each word in characters, and the length of the word sequence (which is passed as batch_size x seq_len). Any ideas on how to circumvent this problem would be most welcome.
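One thing that could be tried is declaring every axis that may vary, since the export call above only marks axis 1. The following is an untested sketch of such a `dynamic_axes` mapping, assuming that axis 0 (batch size) and axis 2 of `char_seqs` (maximum word length) also change between calls. The axis labels are arbitrary names chosen here. This may not be sufficient if the trace records a size as a constant somewhere else in the graph.

```python
# Candidate value for the dynamic_axes argument of torch.onnx.export.
# Keys are the input/output names used in the export call; the inner
# dicts map an axis index to a label of our choosing.
dynamic_axes = {
    "char_seqs": {0: "batch_size", 1: "seq_len", 2: "max_word_len"},
    "char_seqs_lens": {0: "batch_size", 1: "seq_len"},
    "embeddings": {0: "batch_size", 1: "seq_len"},
}

# Sanity check: axes shared between tensors should carry the same label.
for name in ("char_seqs_lens", "embeddings"):
    for axis, label in dynamic_axes[name].items():
        assert dynamic_axes["char_seqs"][axis] == label, (name, axis)

print(sorted(dynamic_axes))  # ['char_seqs', 'char_seqs_lens', 'embeddings']
```

The mapping would replace the `dynamic_axes` argument in the export call, with everything else left unchanged.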