JIT GRU Implementation Optimization

Following this guide and this code, and inspired by this discussion, I whipped up my own GRU cell implementation using TorchScript (JIT) as follows.

import math

import torch
import torch.jit as jit
from torch.nn import Parameter


class JitGRUCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(JitGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.Tensor(3 * hidden_size))
        self.bias_hh = Parameter(torch.Tensor(3 * hidden_size))

        self.reset_parameters()

    def reset_parameters(self):
        # Same initialization scheme as torch.nn.GRUCell
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    @jit.script_method
    def forward(self, x, hidden):
        # type: (Tensor, Tensor) -> Tensor
        x = x.view(-1, x.size(1))
        x_results = torch.mm(x, self.weight_ih.t()) + self.bias_ih
        h_results = torch.mm(hidden, self.weight_hh.t()) + self.bias_hh

        # Note: no squeeze() here -- squeezing would collapse the batch
        # dimension when the batch size is 1 and break chunk(3, 1) below.
        i_r, i_z, i_n = x_results.chunk(3, 1)
        h_r, h_z, h_n = h_results.chunk(3, 1)

        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)

        # h' = (1 - z) * n + z * h
        return n - torch.mul(n, z) + torch.mul(z, hidden)

The implementation itself looks very straightforward; however, the performance is not great.

What are some tricks I could use to optimize this implementation further? I'm very new to JIT script modules and would appreciate any and all suggestions.


Hi, Maghoumi. I've suffered from the same problem as you. Did you manage to fix it?

Sorry for the super late response; I just randomly saw this thread again. Nope, still using the same code. Did you manage to solve it?

What do you mean by "return h"? IIRC, the output of a GRU cell is the hidden state itself (you feed it back in as the previous hidden state).

Also, I had some unit tests in place and could confirm that this code returns close-enough results to PyTorch's GRUCell implementation.
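For reference, a comparison test along those lines might look like the sketch below. This is not the actual test from my repo; the cell is re-declared compactly here so the snippet runs standalone, and the parameters are copied from `nn.GRUCell` so both cells compute the same function.

```python
import math

import torch
import torch.jit as jit
import torch.nn as nn
from torch.nn import Parameter


class JitGRUCell(jit.ScriptModule):
    """Compact re-declaration of the GRU cell from the post above."""

    def __init__(self, input_size, hidden_size):
        super(JitGRUCell, self).__init__()
        self.weight_ih = Parameter(torch.Tensor(3 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.Tensor(3 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.Tensor(3 * hidden_size))
        self.bias_hh = Parameter(torch.Tensor(3 * hidden_size))
        stdv = 1.0 / math.sqrt(hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    @jit.script_method
    def forward(self, x: torch.Tensor, hidden: torch.Tensor) -> torch.Tensor:
        x_results = torch.mm(x, self.weight_ih.t()) + self.bias_ih
        h_results = torch.mm(hidden, self.weight_hh.t()) + self.bias_hh
        i_r, i_z, i_n = x_results.chunk(3, 1)
        h_r, h_z, h_n = h_results.chunk(3, 1)
        r = torch.sigmoid(i_r + h_r)
        z = torch.sigmoid(i_z + h_z)
        n = torch.tanh(i_n + r * h_n)
        return n - torch.mul(n, z) + torch.mul(z, hidden)


def cells_agree(input_size=8, hidden_size=16, batch=4):
    torch.manual_seed(0)
    reference = nn.GRUCell(input_size, hidden_size)
    custom = JitGRUCell(input_size, hidden_size)

    # Copy the reference parameters into the custom cell; nn.GRUCell uses
    # the same (3 * hidden, ...) layout with gates ordered r, z, n.
    with torch.no_grad():
        custom.weight_ih.copy_(reference.weight_ih)
        custom.weight_hh.copy_(reference.weight_hh)
        custom.bias_ih.copy_(reference.bias_ih)
        custom.bias_hh.copy_(reference.bias_hh)

    x = torch.randn(batch, input_size)
    h = torch.randn(batch, hidden_size)
    return torch.allclose(reference(x, h), custom(x, h), atol=1e-6)
```

With the weights synchronized this way, the two outputs should match up to floating-point tolerance.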

Yep, you are right. Good thing I started with "I might be wrong"… :wink:


By the way, here’s my full implementation (including support for multiple layers, etc.): https://github.com/Maghoumi/JitGRU

Suggestions for improvement are greatly welcome!

If you are comparing the JIT version to the default version, there will certainly be a gap, especially during the backward pass. I noticed that both the implementation from here and your implementation use `+= [i]` to build lists, but I recommend using `.append(i)` instead. I don't know exactly how these two methods are translated to C++, but `.append(i)` seems faster than list concatenation, especially when you are adding many elements.
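To make the two patterns concrete, here is a minimal sketch of both in TorchScript (the function names and the doubling workload are just illustrative; the actual speed difference will depend on your PyTorch version and loop length, so it is worth benchmarking rather than taking on faith):

```python
from typing import List

import torch
import torch.jit as jit


@jit.script
def collect_with_append(x: torch.Tensor) -> torch.Tensor:
    # Grows the list in place with .append()
    outputs = jit.annotate(List[torch.Tensor], [])
    for i in range(x.size(0)):
        outputs.append(x[i] * 2.0)
    return torch.stack(outputs)


@jit.script
def collect_with_concat(x: torch.Tensor) -> torch.Tensor:
    # Builds a one-element temporary list on every iteration
    outputs = jit.annotate(List[torch.Tensor], [])
    for i in range(x.size(0)):
        outputs += [x[i] * 2.0]
    return torch.stack(outputs)
```

Both functions return the same tensor; the `append` variant simply avoids allocating a throwaway single-element list per time step, which is where the suspected overhead comes from.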