Tracing a GRU layer in a module with the @torch.jit.export decorator

Hello,

I am trying to trace a model that uses a module containing a GRU layer. The module has two methods, forward and online_inference. The outer model calls only online_inference, so that method has to be traced as well. However, I get a strange error as soon as I also expose online_inference to the tracer.

I isolated a minimal reproducible snippet:

```python
import torch
import torch.nn as nn
from torch import Tensor


class Model(nn.Module):

    def __init__(self, hidden_dim=8, n_layers=1, batch_first=True):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=n_layers, batch_first=batch_first)

    def forward(self, c0: Tensor, state: list[Tensor]) -> Tensor:
        gru_h = state[0]
        c, _ = self.gru(c0, gru_h)
        return c

    # Marks online_inference as an extra entry point to be compiled alongside forward
    @torch.jit.export
    def online_inference(self, c0: Tensor, state: list[Tensor]) -> Tensor:
        gru_h = state[0]
        c, _ = self.gru(c0, gru_h)
        return c


model = Model().eval()

batch = 2
device = "cpu"
inputs = (
    torch.randn(batch, 64, model.hidden_dim, device=device),  # c0: (batch, seq_len, hidden_dim)
    [torch.zeros(model.n_layers, batch, model.hidden_dim, device=device)],  # initial GRU hidden state
)
with torch.no_grad():
    trace = torch.jit.trace(model, inputs, check_trace=False, strict=False)
```

This fails with:

```
RuntimeError: Couldn't find method: 'forward__0' on class: '__torch__.torch.nn.modules.rnn.___torch_mangle_151.GRU (of Python compilation unit at: 0x564ff038b8f0)'
```

Removing the @torch.jit.export decorator, removing online_inference entirely, or removing the GRU call from online_inference makes the error go away, but I don't want to do any of those. Any ideas what is going on?
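One possible workaround, assuming every operation in the real model is TorchScript-compatible (true for this minimal snippet, not verified for the full model), is to compile the model with torch.jit.script instead of torch.jit.trace. The forward__0 name in the error suggests the exported method was compiled against a scripted overload of nn.GRU.forward that the traced GRU submodule does not provide; this is an inference from the message, not a confirmed diagnosis. The sketch below redefines the snippet's Model using typing.List for the annotations (an assumption for compatibility with older Python versions) and checks that the scripted online_inference matches eager mode.

```python
from typing import List

import torch
import torch.nn as nn
from torch import Tensor


class Model(nn.Module):
    # Same model as in the snippet; typing.List replaces list[...] in annotations
    def __init__(self, hidden_dim: int = 8, n_layers: int = 1, batch_first: bool = True):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=n_layers, batch_first=batch_first)

    def forward(self, c0: Tensor, state: List[Tensor]) -> Tensor:
        gru_h = state[0]
        c, _ = self.gru(c0, gru_h)
        return c

    @torch.jit.export
    def online_inference(self, c0: Tensor, state: List[Tensor]) -> Tensor:
        gru_h = state[0]
        c, _ = self.gru(c0, gru_h)
        return c


model = Model().eval()
batch = 2
x = torch.randn(batch, 64, model.hidden_dim)
h0 = [torch.zeros(model.n_layers, batch, model.hidden_dim)]

# Script instead of trace: both forward and the exported online_inference
# are compiled from source, including the call into the scripted nn.GRU.
scripted = torch.jit.script(model)

with torch.no_grad():
    eager_out = model.online_inference(x, h0)
    scripted_out = scripted.online_inference(x, h0)
    scripted_fwd = scripted(x, h0)
```

The scripted module can be saved with torch.jit.save and reloaded with torch.jit.load like a traced one. This sidesteps the tracing error rather than explaining it, so it only helps if scripting the whole model is acceptable.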