How to jit custom module | ParallelModuleList implementation for GRU

Hey there!

I am trying to implement a model where each ‘pixel’ of the input runs through its own GRU (instead of flattening the image and feeding it through the feature input of a single GRU).
This forces me to loop over a ModuleList of GRUs to run the forward pass for each pixel.
I am trying to speed it up using TorchScript, but without success. :(
I want to use this custom module during training.

This is my module:

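import time

import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class ParallelGruList(nn.Module):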

    def __init__(self, input_size, hidden_size, num_grus, device):
        super().__init__()
        self.num_grus = num_grus
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.device = device  #  passing the device for torch script

        self.grus = nn.ModuleList(self.get_gru() for _ in range(self.num_grus)).to(device)

    def get_gru(self):
        return nn.GRU(input_size=self.input_size,
                      hidden_size=self.hidden_size,
                      num_layers=1,
                      batch_first=True,
                      bias=False)

    def forward(self, x):

        batch_size = x.shape[0]
        sequence_length = x.shape[1]

        output = torch.zeros((batch_size, sequence_length, self.hidden_size, self.num_grus), dtype=torch.float32, device=self.device)

        for i, gru in enumerate(self.grus):
            output[:, :, :, i], _ = gru(x[:, :, :, i], None)

        return output
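One commonly used alternative to looping over 196 separate `nn.GRU` modules (not something confirmed in this thread) is to stack the per-pixel weights into single parameters and evaluate all GRUs together with a batched matrix product at every time step. The sketch below assumes single-layer GRUs without bias, a zero initial hidden state, and the `(r, z, n)` gate order and equations from the `nn.GRU` documentation; the class name `BatchedIndependentGRU` is hypothetical. It keeps the same input/output layout as `ParallelGruList`, but it does not use cuDNN's fused GRU kernel and still loops over the 100 time steps in Python, so it is worth benchmarking before adopting.

```python
import torch
import torch.nn as nn


class BatchedIndependentGRU(nn.Module):
    """num_grus independent single-layer GRUs (no bias), evaluated together.

    Input:  (batch, seq_len, input_size, num_grus)   same layout as ParallelGruList
    Output: (batch, seq_len, hidden_size, num_grus)
    """

    def __init__(self, input_size: int, hidden_size: int, num_grus: int):
        super().__init__()
        self.hidden_size = hidden_size
        # One weight matrix per GRU, stacked along dim 0; gate order (r, z, n) as in nn.GRU.
        self.weight_ih = nn.Parameter(torch.empty(num_grus, 3 * hidden_size, input_size))
        self.weight_hh = nn.Parameter(torch.empty(num_grus, 3 * hidden_size, hidden_size))
        bound = 1.0 / hidden_size ** 0.5  # same uniform init range nn.GRU uses
        nn.init.uniform_(self.weight_ih, -bound, bound)
        nn.init.uniform_(self.weight_hh, -bound, bound)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, seq_len, _, num_grus = x.shape
        # Input projections for all time steps and all GRUs at once:
        # (b, t, i, g) x (g, 3h, i) -> (b, t, g, 3h)
        gi = torch.einsum('btig,gki->btgk', x, self.weight_ih)
        h = x.new_zeros(batch, num_grus, self.hidden_size)  # zero initial state
        outputs = []
        for t in range(seq_len):
            # Hidden projections for every GRU in one batched product: (b, g, 3h)
            gh = torch.einsum('bgh,gkh->bgk', h, self.weight_hh)
            i_r, i_z, i_n = gi[:, t].chunk(3, dim=-1)
            h_r, h_z, h_n = gh.chunk(3, dim=-1)
            r = torch.sigmoid(i_r + h_r)      # reset gate
            z = torch.sigmoid(i_z + h_z)      # update gate
            n = torch.tanh(i_n + r * h_n)     # candidate state
            h = (1 - z) * n + z * h
            outputs.append(h)
        # (t, b, g, h) -> (b, t, h, g) to match ParallelGruList's output layout
        return torch.stack(outputs, dim=0).permute(1, 0, 3, 2)
```

Because the weights live in two stacked parameters, a `state_dict` from `ParallelGruList` would not load into this module directly; the per-pixel slices `weight_ih[g]` and `weight_hh[g]` correspond to `weight_ih_l0` and `weight_hh_l0` of the g-th `nn.GRU`.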

And this is how I test it:

def jit_gru(state, grus):
    start = time.time()
    r = grus(state)
    return time.time() - start

if __name__ == '__main__':
    print('\n')

    data = torch.rand(1, 100, 3, 14 * 14, dtype=torch.float32, device=device)
    print('init gru')
    grus = ParallelGruList(input_size=3, hidden_size=10, num_grus=14 * 14, device=device).to(device)
    print('run test')
    h = []
    for i in range(10):
        h.append(jit_gru(data, grus))
    print('exe time smart jit', np.mean(h))
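A caveat about the measurement itself: CUDA kernels are launched asynchronously, so `time.time()` around a forward call on the GPU may not capture when the work actually finishes, and the first few calls of a scripted module are typically slower because TorchScript profiles and optimizes the graph during them. A minimal timing sketch that accounts for both, assuming the module takes a single tensor input (the helper name `time_module` and its `warmup`/`iters` parameters are hypothetical):

```python
import time

import torch


def time_module(module, inputs, warmup=3, iters=10):
    """Average wall-clock seconds per forward pass of `module(inputs)`."""
    # Warm-up calls, so a TorchScript module can finish its profiling runs.
    for _ in range(warmup):
        module(inputs)
    # Wait for queued GPU work before starting the clock.
    if inputs.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        module(inputs)
    # Wait again so every timed kernel has actually finished.
    if inputs.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters
```

With the objects from the test script, this would be called as `time_module(grus, data)` for both the eager and the scripted module.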

Running it like this will generate the following output:

init gru
run test
exe time smart jit 0.2266002893447876

To speed things up, I tried scripting the module:

grus = torch.jit.script(ParallelGruList(input_size=3, hidden_size=10, num_grus=14 * 14, device=device).to(device))

But this makes it a little slower:

exe time smart jit 0.296830940246582

I also tried to mark the ModuleList as a constant (via __constants__), but this only produces the warning:

UserWarning: 'grus' was found in ScriptModule constants,  but it is a non-constant submodule. Consider removing it.
  warnings.warn("'{}' was found in ScriptModule constants, "
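As the warning suggests, a ModuleList does not need to be listed in `__constants__` for scripting: TorchScript can iterate over it directly and unrolls the loop at compile time, emitting one call per submodule. That also means scripting `ParallelGruList` still performs 196 separate GRU calls. A minimal sketch with a hypothetical `Stack` module of `nn.Linear` layers standing in for the GRUs:

```python
import torch
import torch.nn as nn


class Stack(nn.Module):
    # No __constants__ entry: TorchScript iterates over the ModuleList directly.
    def __init__(self, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(num_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unrolled by the compiler into one call per submodule.
        for layer in self.layers:
            x = layer(x)
        return x


eager = Stack(3)
scripted = torch.jit.script(eager)
```

`enumerate(self.layers)`, as used in `ParallelGruList.forward`, can be scripted the same way.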

I also tried decorating the get_gru() method with @torch.jit.export, but that caused the error:

Class GRU does not have an __init__ function defined:

I also tried a couple of other things, which usually made the whole thing even slower. I’ve read the JIT reference guide and the documentation of torch.jit.script, but I could not find anything that helped.
At this point I am thoroughly clueless.

Could someone point me in the right direction or let me know what I am missing?

Best,
Constantin