Cannot reproduce gradients of GRU

I implemented a custom GRU following this custom LSTM example. Compared with the PyTorch native GRU, I can reproduce the outputs/loss, but not the gradients.

The custom GRU cell and layer are the following:

import torch
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.nn import Parameter
from typing import List

class GRUCell(jit.ScriptModule):
    __constants__ = ['ngate']
    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.ngate = 3
        # Gates are stacked along dim 0 in the order (reset, update, new),
        # the same layout as nn.GRU's weight_ih_l0 / weight_hh_l0.
        self.w_ih = Parameter(torch.randn(self.ngate * hidden_size, input_size))
        self.w_hh = Parameter(torch.randn(self.ngate * hidden_size, hidden_size))
        self.b_ih = Parameter(torch.randn(self.ngate * hidden_size))
        self.b_hh = Parameter(torch.randn(self.ngate * hidden_size))

    @jit.script_method
    def forward(self, inputs, hidden):
        # type: (Tensor, Tensor) -> Tensor
        gi = torch.mm(inputs, self.w_ih.t()) + self.b_ih
        gh = torch.mm(hidden, self.w_hh.t()) + self.b_hh
        i_r, i_i, i_n = gi.chunk(self.ngate, 1)
        h_r, h_i, h_n = gh.chunk(self.ngate, 1)
        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + resetgate * h_n)
        # h' = (1 - z) * n + z * h, written as n + z * (h - n)
        hy = newgate + inputgate * (hidden - newgate)
        return hy
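
As a quick check on the cell in isolation, a single step can be compared against torch.nn.GRUCell, which (as far as I can tell) uses the same gate layout and registers its parameters in the same (weight_ih, weight_hh, bias_ih, bias_hh) order. The sizes below are just illustrative:

# Single-step check of the custom cell against nn.GRUCell (illustrative sizes)
torch.manual_seed(0)
x = torch.randn(4, 3)   # (batch, input_size)
h = torch.randn(4, 2)   # (batch, hidden_size)
custom_cell = GRUCell(3, 2)
ref_cell = nn.GRUCell(3, 2)
with torch.no_grad():
    for ref_p, cus_p in zip(ref_cell.parameters(), custom_cell.parameters()):
        ref_p.copy_(cus_p)
custom_cell(x, h).sum().backward()
ref_cell(x, h).sum().backward()
for cus_p, ref_p in zip(custom_cell.parameters(), ref_cell.parameters()):
    print((cus_p.grad - ref_p.grad).abs().max().item())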

class GRULayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(GRULayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, inputs, out):
        # type: (Tensor, Tensor) -> Tensor
        # inputs: (seq_len, batch, input_size); out: initial hidden state (batch, hidden_size)
        steps = inputs.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(steps)):
            out = self.cell(steps[i], out)
            outputs += [out]
        return torch.stack(outputs)   # (seq_len, batch, hidden_size)
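
For reference, GRULayer is meant to mirror nn.GRU with batch_first = False: it takes an input of shape (seq_len, batch, input_size) plus a 2-D initial hidden state of shape (batch, hidden_size) (unlike nn.GRU, which expects a 3-D state, hence the unsqueeze(0) in the comparison below), and returns the stacked per-step hidden states. A quick shape check with illustrative sizes:

layer = GRULayer(GRUCell, 3, 2)
x = torch.randn(5, 4, 3)   # (seq_len, batch, input_size)
h0 = torch.zeros(4, 2)     # (batch, hidden_size)
y = layer(x, h0)
print(y.shape)             # torch.Size([5, 4, 2])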

The outputs and gradients are compared:

torch.manual_seed(10)
seq_len = 5
batch = 5
input_size = 3
num_classes = 2
hidden_size = num_classes
criterion = nn.CrossEntropyLoss()
inp = torch.randn(seq_len, batch, input_size)
label = torch.randint(low = 0, high = num_classes, size = (batch,))
state = torch.randn(batch, hidden_size)

# Custom GRU
rnn = GRULayer(GRUCell, input_size, hidden_size)
out = rnn(inp, state)
out = out[-1]                                     # last time step: (batch, hidden_size)
loss = criterion(out, label)
loss.backward()
gradients = [x.grad for x in rnn.parameters()]    # order: w_ih, w_hh, b_ih, b_hh

# Control: PyTorch native GRU with the same weights
native_gru = nn.GRU(input_size, hidden_size, 1, batch_first = False)
native_gru_state = state.unsqueeze(0)             # nn.GRU expects (num_layers, batch, hidden_size)
for native_gru_param, custom_param in zip(native_gru.all_weights[0], rnn.parameters()):
    assert native_gru_param.shape == custom_param.shape
    with torch.no_grad():
        native_gru_param.copy_(custom_param)
native_gru_out, native_gru_out_state = native_gru(inp, native_gru_state)
native_gru_out = native_gru_out[-1, :, :]
native_gru_loss = criterion(native_gru_out, label)
native_gru_loss.backward()
native_gru_gradients = [x.grad for x in native_gru.all_weights[0]]   # same parameter order as above

print("loss is", loss.item())
print("loss difference is", (loss - native_gru_loss).max().item())
print("gradient differences are")
for x, y in zip(gradients, native_gru_gradients):
    # print(x.abs().max().item())
    # print(y.abs().max().item())
    print((x - y).abs().max().item())

Here is the output (the four gradient differences correspond, in order, to w_ih, w_hh, b_ih and b_hh):

loss is 0.5983431935310364
loss difference is 0.0
gradient differences are
0.16684868931770325
0.10169483721256256
0.08745706081390381
0.06843984127044678

I think the problem lies in GRULayer, since I can reproduce the gradients for a single GRUCell (e.g. by setting seq_len = 1). What makes this even stranger is that I can reproduce the gradients of the PyTorch native LSTM layer with the original custom LSTM provided in the link.
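
If it helps to narrow things down, one more comparison would be to unroll torch.nn.GRUCell over the sequence as a third reference with the same weights, and check its gradients against both implementations. A rough, untested sketch (reusing inp, state, label, criterion and the gradient lists from the comparison code above):

# Third reference: nn.GRUCell unrolled over the sequence with the same weights
ref_cell = nn.GRUCell(input_size, hidden_size)
with torch.no_grad():
    for ref_param, custom_param in zip(ref_cell.parameters(), rnn.parameters()):
        ref_param.copy_(custom_param)

h = state
for t in range(seq_len):
    h = ref_cell(inp[t], h)
ref_loss = criterion(h, label)
ref_loss.backward()
ref_gradients = [p.grad for p in ref_cell.parameters()]

print("custom vs unrolled GRUCell:")
for x, y in zip(gradients, ref_gradients):
    print((x - y).abs().max().item())
print("native GRU vs unrolled GRUCell:")
for x, y in zip(native_gru_gradients, ref_gradients):
    print((x - y).abs().max().item())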

Your help is very much appreciated!