Why my "manual" GRU is slower than nn.GRU?

I’m trying to build a GRU network out of simpler nn blocks. The nn.GRU version runs really fast, but my version is painfully slow, even though it should also be running on the GPU. I’m trying to understand why that is.
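
For context, the timing comparison is roughly this (a minimal sketch; the warm-up and the explicit torch.cuda.synchronize() calls are there so I measure actual GPU compute rather than just kernel launches):

import time
import torch

def time_forward(model, x, n_iters=20):
    # warm-up: the first calls pay one-time costs (CUDA init, cuDNN autotuning)
    for _ in range(3):
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(n_iters):
        model(x)
    torch.cuda.synchronize()  # wait for all queued kernels before stopping the clock
    return (time.time() - start) / n_iters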

Here’s the nn.GRU version:

import torch
from torch import nn, Tensor

class GRUBuiltin(nn.Module):  # class name is just for this post
    def __init__(self, in_dim, h_dim, out_dim, n_layers, dropout=0):
        super().__init__()
        self.rnn = nn.GRU(in_dim, h_dim, n_layers, dropout=dropout)
        self.decoder = nn.Linear(h_dim, out_dim)

    def forward(self, input: Tensor, hidden_state: Tensor = None):
        output, hidden = self.rnn(input, hidden_state)
        layer_output = self.decoder(output)
        return layer_output, hidden
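
This one is called with the default (seq_len, batch, features) layout that nn.GRU expects, e.g. (sizes and the class name are just placeholders for this post):

model = GRUBuiltin(in_dim=32, h_dim=128, out_dim=10, n_layers=2).cuda()
x = torch.randn(200, 64, 32, device='cuda')  # (seq_len, batch, features)
out, hidden = model(x)                       # out: (200, 64, 10), hidden: (2, 64, 128)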

And here’s my “manual” version:

class GRUManual(nn.Module):  # class name is just for this post
    def __init__(self, in_dim, h_dim, out_dim, n_layers, dropout=0):
        super().__init__()
        self.h_dim, self.out_dim, self.n_layers = h_dim, out_dim, n_layers
        self.drop = nn.Dropout(dropout)
        params = dict()
        for i in range(n_layers):
            if i != 0:
                in_dim = h_dim  # layers after the first see the previous layer's output
            params[f'L{i}W1'] = nn.Linear(in_dim, h_dim)             # update gate, input part
            params[f'L{i}W2'] = nn.Linear(h_dim, h_dim, bias=False)  # update gate, hidden part
            params[f'L{i}W3'] = nn.Linear(in_dim, h_dim)             # reset gate, input part
            params[f'L{i}W4'] = nn.Linear(h_dim, h_dim, bias=False)  # reset gate, hidden part
            params[f'L{i}W5'] = nn.Linear(in_dim, h_dim)             # candidate state, input part
            params[f'L{i}W6'] = nn.Linear(h_dim, h_dim, bias=False)  # candidate state, hidden part
        params['W7'] = nn.Linear(h_dim, out_dim)                     # output decoder
        for name, module in params.items():
            self.add_module(name, module)  # register submodules so .cuda()/.parameters() see them
        self.params = params

    def forward(self, input: Tensor, hidden_state: Tensor = None):
        batch_size, seq_len, _ = input.shape  # input is (batch, seq_len, features) here
        if hidden_state is None:
            layer_states = [torch.zeros(batch_size, self.h_dim, device=input.device)
                            for _ in range(self.n_layers)]
        else:
            layer_states = list(hidden_state)  # one (batch, h_dim) state per layer

        params = self.params
        layer_middle = input
        for i in range(self.n_layers):
            layer_input = layer_middle
            layer_middle = torch.zeros(batch_size, seq_len, self.h_dim, device=input.device)
            for j in range(seq_len):  # step through the sequence one timestep at a time
                x = layer_input[:, j, :]
                z = torch.sigmoid(params[f'L{i}W1'](x) + params[f'L{i}W2'](layer_states[i]))
                r = torch.sigmoid(params[f'L{i}W3'](x) + params[f'L{i}W4'](layer_states[i]))
                g = torch.tanh(params[f'L{i}W5'](x) + r * params[f'L{i}W6'](layer_states[i]))
                layer_states[i] = z * layer_states[i] + (1 - z) * g
                layer_middle[:, j, :] = layer_states[i]
            layer_middle = self.drop(layer_middle)

        layer_output = torch.zeros(batch_size, seq_len, self.out_dim, device=input.device)
        for j in range(seq_len):
            x = layer_middle[:, j, :]
            layer_output[:, j, :] = params['W7'](x)
        return layer_output, torch.stack(layer_states)
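
For reference, the per-timestep update each layer is meant to compute is the usual GRU one, with (W1, W2) forming the update gate, (W3, W4) the reset gate and (W5, W6) the candidate state:

z_t = sigmoid(W1·x_t + W2·h_{t-1})
r_t = sigmoid(W3·x_t + W4·h_{t-1})
g_t = tanh(W5·x_t + r_t * (W6·h_{t-1}))
h_t = z_t * h_{t-1} + (1 - z_t) * g_t

(I left the bias off the hidden-side Linears, but otherwise this is meant to be the same update nn.GRU performs.)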