Hello,
I am trying to trace a model that uses a module with a GRU layer. The module inside the model has both a forward and an online_inference method. The model only calls online_inference, so that method has to be traced as well. However, I get a strange error as soon as I expose online_inference to the tracer.
I isolated the problem in this reproducible snippet:
import torch
import torch.nn as nn
from torch import Tensor
class Model(nn.Module):
    def __init__(self, hidden_dim=8, n_layers=1, batch_first=True):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=n_layers, batch_first=batch_first)
    def forward(self, c0: Tensor, state: list[Tensor]) -> Tensor:
        gru_h = state[0]
        c, _ = self.gru(c0, gru_h)
        return c
    @torch.jit.export
    def online_inference(self, c0, state: list[Tensor]) -> Tensor:
        gru_h = state[0]
        c, _ = self.gru(c0, gru_h)
        return c
model = Model().eval()
batch = 2
device = "cpu"
inputs = (
    torch.randn(batch, 64, model.hidden_dim),
    [torch.zeros(model.n_layers, batch, model.hidden_dim, device=device)],
)
with torch.no_grad():
    trace = torch.jit.trace(model, inputs, check_trace=False, strict=False)
which exits with:
RuntimeError: Couldn't find method: 'forward__0' on class: '__torch__.torch.nn.modules.rnn.___torch_mangle_151.GRU (of Python compilation unit at: 0x564ff038b8f0)'
Removing the @torch.jit.export decorator, the whole online_inference method, or the GRU call inside online_inference makes the error go away, of course, but I don't want to do any of that. Any ideas what is going on?
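One direction that might be worth comparing against, sketched here under the assumption that the whole module is scriptable and that scripting is acceptable in place of tracing: torch.jit.script compiles forward together with every @torch.jit.export method, so it does not go through the tracing path that raises the error above. The class name ScriptedModel is hypothetical, typing.List is used instead of list[Tensor] only to stay compatible with older TorchScript versions, and this is not a confirmed fix for the trace failure itself.

```python
from typing import List

import torch
import torch.nn as nn
from torch import Tensor


class ScriptedModel(nn.Module):  # hypothetical name, same layout as Model above
    def __init__(self, hidden_dim: int = 8, n_layers: int = 1, batch_first: bool = True):
        super().__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_layers=n_layers, batch_first=batch_first)

    def forward(self, c0: Tensor, state: List[Tensor]) -> Tensor:
        c, _ = self.gru(c0, state[0])
        return c

    @torch.jit.export
    def online_inference(self, c0: Tensor, state: List[Tensor]) -> Tensor:
        c, _ = self.gru(c0, state[0])
        return c


model = ScriptedModel().eval()
batch = 2
c0 = torch.randn(batch, 64, model.hidden_dim)
state = [torch.zeros(model.n_layers, batch, model.hidden_dim)]

# Script the module instead of tracing it (assumption: scripting is acceptable here).
scripted = torch.jit.script(model)

with torch.no_grad():
    out_forward = scripted(c0, state)
    out_online = scripted.online_inference(c0, state)
```

If both calls run, the exported method is available on the scripted module next to forward, which would at least narrow the problem down to how torch.jit.trace handles modules with exported methods.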