Torchscript for model with GRU

I have a model that contains a GRU. I traced the model, but when I load it and call the forward method, it only works if the sequence length matches that of the dummy input used for tracing. I get this warning when tracing:

TracerWarning: Iterating over a tensor might cause the trace to be incorrect.
Passing a tensor of different shape won't change the number of iterations 
executed (and might lead to errors or silently give incorrect results).

My dummy sequence has a length of 25 with features of size 256 (25 x 256 = 6400 values in total). I chose the length arbitrarily, since the model should work with variable sequence lengths. However, when I load the model and pass a sequence whose length differs from the dummy input's, I get this error:

RuntimeError: shape ‘[-1, 30, 256]’ is invalid for input of size 6400
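To see what the warning refers to, here is a minimal reproduction under stated assumptions: `LoopedGRU` is a hypothetical stand-in for the real model that steps an `nn.GRUCell` over the sequence with a Python `for` loop over a tensor, which is the pattern that triggers this `TracerWarning`. The tracer records the 25 iterations it saw, so the traced graph cannot adapt to another length, while the eager model can. Whether the traced call raises or silently returns 25 steps depends on the ops that were recorded, so the sketch only reports what happens.

```python
import warnings

import torch
import torch.nn as nn


class LoopedGRU(nn.Module):
    """Hypothetical model that steps a GRUCell over the sequence in a Python loop."""

    def __init__(self, features: int = 256, hidden: int = 64):
        super().__init__()
        self.cell = nn.GRUCell(features, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, features)
        h = x.new_zeros(x.size(0), self.cell.hidden_size)
        outputs = []
        # Iterating over a tensor: the tracer unrolls this loop a fixed number of times
        for x_t in x.transpose(0, 1):
            h = self.cell(x_t, h)
            outputs.append(h)
        return torch.stack(outputs, dim=1)


model = LoopedGRU().eval()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(model, torch.randn(1, 25, 256))  # dummy length 25

tracer_msgs = [str(w.message) for w in caught if issubclass(w.category, torch.jit.TracerWarning)]
print(any("Iterating over a tensor" in m for m in tracer_msgs))  # True

# The eager model handles any length
print(model(torch.randn(1, 30, 256)).shape)  # torch.Size([1, 30, 64])

# The traced model was recorded with exactly 25 loop iterations
try:
    print(traced(torch.randn(1, 30, 256)).shape)
except RuntimeError as err:
    print("traced model failed:", err)
```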

Is this normal (i.e. what the warning was about) and if so are there any workarounds?

I suppose one workaround is to fix a maximum length and pad every input up to it, but this is not optimal.
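A sketch of that padding workaround, assuming inputs laid out as `(batch, seq_len, features)` (the `[-1, 30, 256]` shape in the error appears to suggest this); `pad_to_length` is a hypothetical helper name, and `torch.nn.functional.pad` does the actual work:

```python
import torch
import torch.nn.functional as F


def pad_to_length(x: torch.Tensor, max_len: int) -> torch.Tensor:
    """Zero-pad a (batch, seq_len, features) tensor along seq_len up to max_len."""
    seq_len = x.size(1)
    if seq_len > max_len:
        raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
    # F.pad lists pads from the last dim backwards:
    # (features_left, features_right, seq_left, seq_right)
    return F.pad(x, (0, 0, 0, max_len - seq_len))


x = torch.randn(1, 17, 256)
padded = pad_to_length(x, 25)
print(padded.shape)  # torch.Size([1, 25, 256])
```

With a unidirectional GRU, outputs at the real time steps do not depend on the later padded steps, so `out[:, :seq_len]` is unaffected; the final hidden state, however, has also run over the padding, which is one reason this workaround is not ideal.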

AFAICT this is exactly the case where tracing doesn’t work well. See the Warning section of the torch.jit.trace documentation page:

https://pytorch.org/docs/stable/generated/torch.jit.trace.html#torch.jit.trace


Tracing will not record any control-flow like if-statements or loops.

It also offers a possible solution:

In cases like these, tracing would not be appropriate and scripting is a better choice. If you trace such models, you may silently get incorrect results on subsequent invocations of the model. The tracer will try to emit warnings when doing something that may cause an incorrect trace to be produced.
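A small illustration of that warning, using a hypothetical function `double_or_negate` with a data-dependent `if` (not the GRU model itself): tracing with a positive input bakes in the `x * 2` branch, whereas `torch.jit.script` compiles the `if` statement and picks the branch at run time.

```python
import warnings

import torch


def double_or_negate(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent if-statement: the branch taken depends on the values in x
    if x.sum() > 0:
        return x * 2
    return -x


with warnings.catch_warnings():
    # Tracing warns that converting a tensor to a Python bool may give a wrong trace
    warnings.simplefilter("ignore", torch.jit.TracerWarning)
    traced = torch.jit.trace(double_or_negate, torch.ones(3))

scripted = torch.jit.script(double_or_negate)

neg = -torch.ones(3)
print(traced(neg))    # tensor([-2., -2., -2.]): the "x * 2" branch was frozen in
print(scripted(neg))  # tensor([1., 1., 1.]): the if-statement is kept
```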

The problem in my case seems to be the handling of variable-sized tensors (due to variable sequence lengths), which I assume is implemented internally with a loop. I will look into scripting, but if the loop is inside the PyTorch code, how would I script it? Another alternative I just considered is passing the GRU hidden state into the network as an input and having the network return the updated hidden state for use in later steps. That way I suppose I could trace with a dummy of length 1 and run the loop outside the model. Would something like this work?
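A minimal sketch of the hidden-state idea, assuming a single-layer `nn.GRU` with `batch_first=True`; `GRUStep` and the sizes are illustrative, not taken from the original model. The model is traced once with a length-1 dummy and the time loop lives in ordinary Python, so no loop is ever recorded in the trace. The last lines compare against the eager model run over the whole sequence in one call.

```python
import torch
import torch.nn as nn


class GRUStep(nn.Module):
    """Hypothetical wrapper: takes the hidden state as input and returns the new one."""

    def __init__(self, features: int = 256, hidden: int = 64):
        super().__init__()
        self.gru = nn.GRU(features, hidden, batch_first=True)

    def forward(self, x: torch.Tensor, h: torch.Tensor):
        # x: (batch, steps, features), h: (num_layers, batch, hidden)
        return self.gru(x, h)


model = GRUStep().eval()
h0 = torch.zeros(1, 1, 64)
# Trace with a single time step; the outer loop below never enters the trace
step = torch.jit.trace(model, (torch.randn(1, 1, 256), h0))

seq = torch.randn(1, 30, 256)  # a length different from anything used for tracing
with torch.no_grad():
    h = h0
    outputs = []
    for t in range(seq.size(1)):
        out, h = step(seq[:, t:t + 1], h)  # feed one step, carry the hidden state
        outputs.append(out)
    stepped = torch.cat(outputs, dim=1)

    # Reference: the eager model over the whole sequence in one call
    full, h_full = model(seq, h0)

print(stepped.shape)                             # torch.Size([1, 30, 64])
print(torch.allclose(stepped, full, atol=1e-5))  # expected True
```

The trade-off is one model call per time step, which adds Python and dispatch overhead compared with processing the whole sequence in a single GRU call.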

I think scripting works fine in that case and is probably the best option:

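import torch
import torch.nn as nn

class MyGRU(nn.Module):
    def __init__(self):
        super().__init__()
        # input_size=10, hidden_size=32, num_layers=5; inputs are (seq_len, batch, 10)
        self.gru = nn.GRU(10, 32, 5)

    def forward(self, x):
        return self.gru(x)

module = torch.jit.script(MyGRU())
# The first dimension is the sequence length (batch_first=False), so it can vary between calls
print(module(torch.randn(100, 15, 10))[0].shape)  # torch.Size([100, 15, 32])
print(module(torch.randn(50, 20, 10))[0].shape)   # torch.Size([50, 20, 32])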
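Since the original problem showed up after saving and reloading, here is a sketch of the same idea through a save/load round trip with `torch.jit.save` and `torch.jit.load`. An in-memory `io.BytesIO` buffer stands in for a file path (an assumption to keep the example self-contained); with the default `batch_first=False`, the first input dimension is the sequence length.

```python
import io

import torch
import torch.nn as nn


class MyGRU(nn.Module):
    def __init__(self):
        super().__init__()
        self.gru = nn.GRU(10, 32, 5)  # input_size=10, hidden_size=32, num_layers=5

    def forward(self, x):
        return self.gru(x)


scripted = torch.jit.script(MyGRU())
buffer = io.BytesIO()  # stands in for a file such as "model.pt"
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

for seq_len in (7, 25, 30):
    out, h = loaded(torch.randn(seq_len, 4, 10))  # (seq_len, batch, features)
    print(out.shape, h.shape)  # e.g. torch.Size([7, 4, 32]) torch.Size([5, 4, 32])
```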