Following this guide and this code, and inspired by this discussion, I wrote my own GRU cell implementation using TorchScript (JIT), shown below.

```
class JitGRUCell(jit.ScriptModule):
    """Single-step GRU cell implemented as a TorchScript ``ScriptModule``.

    The parameter layout and gate order (reset, update, new) match
    ``torch.nn.GRUCell``:

    - ``weight_ih`` is ``(3 * hidden_size, input_size)``.
    - ``weight_hh`` is ``(3 * hidden_size, hidden_size)``.
    - ``bias_ih`` and ``bias_hh`` are ``(3 * hidden_size,)``.

    Args:
        input_size: Number of features in the input ``x``.
        hidden_size: Number of features in the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # torch.empty makes it explicit that the storage starts uninitialized;
        # reset_parameters() fills every parameter right below.
        self.weight_ih = Parameter(torch.empty(3 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.empty(3 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.empty(3 * hidden_size))
        self.bias_hh = Parameter(torch.empty(3 * hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        # In-place init without recording autograd history (preferred over .data).
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    @jit.script_method
    def forward(self, x: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        """Advance the GRU by one time step.

        Args:
            x: Input batch; ``x.view(-1, x.size(1))`` must yield a
                ``(batch, input_size)`` matrix.
            hidden: Previous hidden state, ``(batch, hidden_size)``.

        Returns:
            The new hidden state, ``(batch, hidden_size)``.
        """
        x = x.view(-1, x.size(1))
        # addmm folds the bias add into the GEMM call instead of issuing a
        # separate elementwise add afterwards.
        x_results = torch.addmm(self.bias_ih, x, self.weight_ih.t())
        h_results = torch.addmm(self.bias_hh, hidden, self.weight_hh.t())
        # No squeeze(): both results are already (batch, 3 * hidden_size).
        # Squeezing would drop the batch dim when batch == 1 and make
        # chunk(3, 1) fail.
        i_r, i_z, i_n = x_results.chunk(3, 1)
        h_r, h_z, h_n = h_results.chunk(3, 1)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        # Equal to n - n * z + z * hidden == (1 - z) * n + z * hidden,
        # but with one fewer multiply.
        return n + z * (hidden - n)
```

The implementation itself is very straightforward; however, its performance is not great.

What tricks could I use to optimize this implementation further? I'm very new to TorchScript (JIT) script modules, and I would appreciate any suggestions.