ONNX export failed on ATen operator _thnn_fused_gru_cell

Hello,
I am trying to export a Bahdanau Attention RNN model from pytorch to onnx, however I have an issue when trying to convert it. It is most probably caused by the GRUCell layer.
I get this error while trying to convert :

/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/utils.py:655: UserWarning: ONNX export failed on ATen operator _thnn_fused_gru_cell because torch.onnx.symbolic_opset11._thnn_fused_gru_cell does not exist
  .format(op_name, opset_version, op_name))
Traceback (most recent call last):
  File "tools_onnx/test_net_onnx.py", line 504, in <module>
    main()
  File "tools_onnx/test_net_onnx.py", line 333, in main
    model_decoder_manager.generate_onnx(input_names = ["decoder_input", "decoder_hidden", "encoder_context"])#, output_names = ["decoder_output", "decoder_hidden", "decoder_attention"])#, dynamic_axes = dynamic)
  File "tools_onnx/test_net_onnx.py", line 469, in generate_onnx
    dynamic_axes=dynamic_axes
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/__init__.py", line 148, in export
    strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/utils.py", line 66, in export
    dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/utils.py", line 416, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/utils.py", line 296, in _model_to_graph
    fixed_batch_size=fixed_batch_size, params_dict=params_dict)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/utils.py", line 135, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/__init__.py", line 179, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/utils.py", line 656, in _run_symbolic_function
    op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
  File "/home/ubuntu/anaconda3/envs/masktextspotter37/lib/python3.7/site-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
    return _registry[(domain, version)][opname]
KeyError: '_thnn_fused_gru_cell'

I guess it’s because the GRUCell is not managed correctly in pytorch onnx, however I saw that the operator “GRU” exists in onnx (documentation available here). I hope there is just a little modification to do in the “symbolic” files.

Here is the model and the code :

class BahdanauAttnDecoderRNN(nn.Module):
    def __init__(
        self,
        hidden_size,
        embed_size,
        output_size,
        n_layers=1,
        dropout_p=0,
        bidirectional=False,
        onehot_size = (8, 32)
    ):
        super(BahdanauAttnDecoderRNN, self).__init__()
        # Define parameters
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        # Define layers
        self.embedding = nn.Embedding(output_size, embed_size)
        self.embedding.weight.data = torch.eye(embed_size)
        # self.dropout = nn.Dropout(dropout_p)
        self.word_linear = nn.Linear(embed_size, hidden_size)
        self.attn = Attn("concat", hidden_size, embed_size, onehot_size[0] + onehot_size[1])
        self.rnn = nn.GRUCell(2 * hidden_size + onehot_size[0] + onehot_size[1], hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, word_input, last_hidden, encoder_outputs):
        """
        :param word_input:
            word input for current time step, in shape (B)
        :param last_hidden:
            last hidden stat of the decoder, in shape (layers*direction*B, hidden_size)
        :param encoder_outputs:
            encoder outputs in shape (H*W, B, C)
        :return
            decoder output
        """
        # Get the embedding of the current input word (last output word)
        word_embedded_onehot = self.embedding(word_input).view(
            1, word_input.size(0), -1
        )  # (1,B,embed_size)
        word_embedded = self.word_linear(word_embedded_onehot)  # (1, B, hidden_size)
        attn_weights = self.attn(last_hidden, encoder_outputs)  # (B, 1, H*W)
        context = attn_weights.bmm(
            encoder_outputs.transpose(0, 1)
        )  # (B, 1, H*W) * (B, H*W, C) = (B,1,C)
        context = context.transpose(0, 1)  # (1,B,C)
        # Combine embedded input word and attended context, run through RNN
        # 2 * hidden_size + W + H: 256 + 256 + 32 + 8 = 552
        rnn_input = torch.cat((word_embedded, context), 2)
        last_hidden = last_hidden.view(last_hidden.size(0), -1)
        rnn_input = rnn_input.view(word_input.size(0), -1)
        hidden = self.rnn(rnn_input, last_hidden)
        if not self.training:
            output = F.softmax(self.out(hidden), dim=1)
        else:
            output = F.log_softmax(self.out(hidden), dim=1)
        # Return final output, hidden state
        # print(output.shape)
        return output, hidden, attn_weights


BahdanauAttnDecoderRNN(
  (embedding): Embedding(38, 38)
  (word_linear): Linear(in_features=38, out_features=256, bias=True)
  (attn): Attn(
    (attn): Linear(in_features=552, out_features=256, bias=True)
  )
  (rnn): GRUCell(552, 256)
  (out): Linear(in_features=256, out_features=38, bias=True)
)

If the layer needs to be modified by another one in order for the export to work, can you please tell me if it is possible to transfer the weights of the previous training to the new architecture ?

If you know how to deal with that issue, it would be really helpful !