Unable to export character embedder to ONNX with dynamic sequence length

I have a module whose purpose is to create word features from characters by passing character embeddings through an RNN and taking the last hidden state for each word. I am trying to export the module to ONNX. Here is the forward pass (note that the RNN is initialized with batch_first=False):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

def forward(self, inputs: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """
    Args:
        inputs (torch.Tensor): Character feature tensor of shape [batch_size x seq_len x max_word_len]
        lengths (torch.Tensor): The lengths of each word [batch_size x seq_len]

    Returns:
        torch.Tensor: The last hidden states of all words.
            [batch_size x seq_len x (layers * directions * hidden_size)]
    """
    inputs_emb = self.embedding(inputs)  # [batch_size x seq_len x max_word_len x embedding_size]
    inputs_emb = self.dropout(inputs_emb)  # [batch_size x seq_len x max_word_len x embedding_size]

    inputs_emb_dims = inputs_emb.size()

    inputs_emb_flat = inputs_emb.view(
        -1, inputs_emb_dims[2], inputs_emb_dims[3]
    )  # [(batch_size * seq_len) x max_word_len x embedding_size]
    lengths_flat = lengths.view(-1)  # [(batch_size * seq_len)]

    seq_len_flat = inputs_emb_flat.size()[0]  # [(batch_size * seq_len)]

    # Move seq dimensions to first prior to RNN
    inputs_emb_flat = inputs_emb_flat.permute(1, 0, 2)  # [max_word_len, (batch_size * seq_len), embedding_size]

    lengths_flat_sorted, lengths_flat_idx_sorted = lengths_flat.sort(dim=0, descending=True)
    inputs_emb_flat_sorted = inputs_emb_flat[:, lengths_flat_idx_sorted]

    packed = pack_padded_sequence(
        inputs_emb_flat_sorted, lengths_flat_sorted.to("cpu"), batch_first=False, enforce_sorted=True
    )
    if self.rnn.mode == "GRU":
        _, h_n = self.rnn(packed)  # [(layers * directions) x seq_len_flat x hidden_size]
    elif self.rnn.mode == "LSTM":
        _, (h_n, _) = self.rnn(packed)  # [(layers * directions) x seq_len_flat x hidden_size]
    # Move sequence dimension to the first dimension and then
    # concatenate the forward/backward hidden states.
    h_n = (
        h_n.permute(1, 0, 2).contiguous().view(seq_len_flat, -1)
    )  # [seq_len x (layers * directions * hidden_size)]

    # Unsort the hidden states.
    _, lengths_flat_idx_unsorted_idx = lengths_flat_idx_sorted.sort(dim=0)
    h_n_unsorted = h_n[lengths_flat_idx_unsorted_idx]  # [seq_len x (layers * directions * hidden_size)]

    # Transform back to original input size
    output = h_n_unsorted.view(
        inputs_emb_dims[0], inputs_emb_dims[1], -1
    )  # [batch_size x seq_len x (layers * directions * hidden_size)]

    return output
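
The sort/pack/unsort pattern in the forward pass can be checked in isolation. The sketch below is self-contained; the toy sizes and random data are assumptions, not taken from the post. It compares the packed GRU's `h_n`, mapped back to the original word order, with the final output obtained by running each word on its own without padding:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence

torch.manual_seed(0)

# Toy sizes, chosen for illustration: 6 words, at most 5 characters each.
n_words, max_word_len, emb_size, hidden_size = 6, 5, 4, 3
rnn = nn.GRU(emb_size, hidden_size, batch_first=False)

x = torch.randn(max_word_len, n_words, emb_size)  # [max_word_len x n_words x emb_size]
lengths = torch.tensor([3, 5, 1, 4, 2, 5])  # true length of each word

# Sort words by length (descending), as enforce_sorted=True requires.
lengths_sorted, idx_sorted = lengths.sort(dim=0, descending=True)
packed = pack_padded_sequence(x[:, idx_sorted], lengths_sorted, batch_first=False, enforce_sorted=True)
_, h_n = rnn(packed)  # [1 x n_words x hidden_size], in sorted order

# Sorting the sort indices yields the inverse permutation, restoring the original order.
_, idx_unsorted = idx_sorted.sort(dim=0)
h_last = h_n[0, idx_unsorted]  # [n_words x hidden_size]

# Reference: run each word alone, without padding, and take its final output.
expected = torch.stack([rnn(x[: int(n), i : i + 1])[0][-1, 0] for i, n in enumerate(lengths)])
matches = torch.allclose(h_last, expected, atol=1e-5)
print(matches)
```

If `matches` is true, the packed `h_n` holds each word's state at its last real character rather than at the padding, which is why the module can take `h_n` directly instead of gathering from the padded output.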

I am using tracing and exporting my model as such:

    torch.onnx.export(
        ...,
        input_names=["char_seqs", "char_seqs_lens"],
        output_names=["embeddings"],
        dynamic_axes={
            "char_seqs": {1: "seq_len"},
            "char_seqs_lens": {1: "seq_len"},
            "embeddings": {1: "seq_len"},
        },
    )

When I export the model, I do not get any errors or warnings. However, when I pass a sequence whose length differs from that of the example input, I get the following error:

onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Expand node. Name:'Expand_39' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:483 void onnxruntime::BroadcastIterator::Append(int64_t, int64_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 500 by 1951
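
For context, the failing `Expand` node uses NumPy-style broadcasting: a dimension can only be stretched if it is 1 or already equal to the target size. The same rule can be seen with `torch.Tensor.expand`. In this sketch the 500 stands in for a size that was presumably fixed when the graph was exported; that reading of the error is an assumption:

```python
import torch

# Broadcasting rule (as in NumPy and ONNX Expand): a dimension can be stretched
# only if it is 1 or already equals the target size.
ok_shape = torch.zeros(1, 8).expand(1951, 8).shape  # 1 -> 1951 is allowed

fixed = torch.zeros(500, 8)  # stands in for a size frozen at export time (assumed)
try:
    fixed.expand(1951, 8)  # 500 -> 1951 is not allowed
    broadcast_failed = False
except RuntimeError as err:
    broadcast_failed = True
    print(err)

print(ok_shape, broadcast_failed)
```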

I do not understand why this fails, since I believe I am taking all the necessary steps to support a dynamic sequence length.

Edit 1

I believe this fails because two lengths vary: the length of each word in characters, and the length of the word sequence (the inputs are shaped batch_size x seq_len). Any ideas on how to get around this would be most welcome.

The solution was to make the batch size dynamic too. This required passing the initial hidden state as an input to the forward pass when tracing for the ONNX export.
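
A minimal sketch of that fix, under stated assumptions: `CharEmbedder` is a hypothetical, stripped-down version of the module (GRU only, no dropout) whose `forward` takes `h0` as a third input, and `export_char_embedder` is a hypothetical helper that marks the batch, sequence and word-length dimensions as dynamic. The file name and vocabulary size are placeholders, and the export call has not been checked against a specific PyTorch or onnxruntime version:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class CharEmbedder(nn.Module):
    """Hypothetical, stripped-down character embedder: GRU only, no dropout."""

    def __init__(self, n_chars: int = 50, emb_size: int = 16, hidden_size: int = 32):
        super().__init__()
        self.embedding = nn.Embedding(n_chars, emb_size, padding_idx=0)
        self.rnn = nn.GRU(emb_size, hidden_size, batch_first=False)

    def forward(self, inputs: torch.Tensor, lengths: torch.Tensor, h0: torch.Tensor) -> torch.Tensor:
        # inputs: [batch x seq_len x max_word_len], lengths: [batch x seq_len] (all >= 1)
        # h0: [layers * directions x (batch * seq_len) x hidden_size], passed in by the caller
        emb = self.embedding(inputs)
        b, s, w, e = emb.size()
        flat = emb.view(b * s, w, e).permute(1, 0, 2)  # [max_word_len x words x emb]
        lengths_sorted, idx_sorted = lengths.view(-1).sort(dim=0, descending=True)
        packed = pack_padded_sequence(
            flat[:, idx_sorted], lengths_sorted.cpu(), batch_first=False, enforce_sorted=True
        )
        # h0 must follow the same (sorted) word order as the packed batch.
        _, h_n = self.rnn(packed, h0[:, idx_sorted])
        _, idx_unsorted = idx_sorted.sort(dim=0)  # inverse permutation
        h_n = h_n.permute(1, 0, 2).contiguous().view(b * s, -1)[idx_unsorted]
        return h_n.view(b, s, -1)  # [batch x seq_len x (layers * directions * hidden_size)]


def export_char_embedder(model: CharEmbedder, path: str = "char_embedder.onnx") -> None:
    """Export with batch, seq_len and word length dynamic (placeholder file name)."""
    b, s, w = 2, 7, 5
    inputs = torch.randint(1, model.embedding.num_embeddings, (b, s, w))
    lengths = torch.randint(1, w + 1, (b, s))
    h0 = torch.zeros(model.rnn.num_layers, b * s, model.rnn.hidden_size)
    torch.onnx.export(
        model,
        (inputs, lengths, h0),
        path,
        input_names=["char_seqs", "char_seqs_lens", "h0"],
        output_names=["embeddings"],
        dynamic_axes={
            "char_seqs": {0: "batch", 1: "seq_len", 2: "max_word_len"},
            "char_seqs_lens": {0: "batch", 1: "seq_len"},
            "h0": {1: "n_words"},
            "embeddings": {0: "batch", 1: "seq_len"},
        },
    )
```

Because `h0` arrives as a graph input, its `(batch * seq_len)` dimension can be declared dynamic instead of being created inside the traced graph. All word lengths must be at least 1, since `pack_padded_sequence` rejects zero-length sequences.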

Hi again,
Could you please explain how you solved your problem? I am facing exactly the same issue and have not found a solution.

As I understand it, the problem arises during the conversion.
In my experience, the ONNX export does not handle broadcasting correctly here: it creates a fixed-size tensor, which then produces an unacceptable result (an error) at inference time.
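
The "fixed tensor" effect can be reproduced with plain tracing, which the TorchScript-based `torch.onnx.export` path uses for an `nn.Module`: any value turned into a Python number during the trace is recorded as a constant. `first_k` is a hypothetical function, not part of the original module:

```python
import warnings

import torch


def first_k(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # int(...) turns a tensor value into a Python number; the tracer cannot
    # follow it, so it is recorded as a constant (with a TracerWarning).
    k = int(lengths.max())
    return x[:, :k]


with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the expected TracerWarning
    traced = torch.jit.trace(first_k, (torch.randn(2, 10), torch.tensor([3, 2])))

x = torch.randn(2, 10)
eager_out = first_k(x, torch.tensor([7, 5]))  # shape [2, 7]
traced_out = traced(x, torch.tensor([7, 5]))  # shape [2, 3]: k was baked in at trace time
print(eager_out.shape, traced_out.shape)
```

The OP's fix of passing the initial hidden state explicitly fits this picture: a tensor whose size is computed during the trace, rather than supplied as an input with a dynamic axis, can end up frozen in the exported graph.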