Cannot reproduce gradients of GRU

I made a custom GRU following this custom LSTM implementation. Compared with the PyTorch native GRU, I am able to reproduce the outputs/loss but not the gradients.

The custom GRU layer is the following:

import torch
import torch.nn as nn
import torch.jit as jit
from torch.nn import Parameter
from torch import Tensor
from typing import List


class GRUCell(jit.ScriptModule):
    __constants__ = ['ngate']

    def __init__(self, input_size, hidden_size):
        super(GRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.ngate = 3  # reset, update (input), and new gates
        self.w_ih = Parameter(torch.randn(self.ngate * hidden_size, input_size))
        self.w_hh = Parameter(torch.randn(self.ngate * hidden_size, hidden_size))
        self.b_ih = Parameter(torch.randn(self.ngate * hidden_size))
        self.b_hh = Parameter(torch.randn(self.ngate * hidden_size))

    @jit.script_method
    def forward(self, inputs, hidden):
        # type: (Tensor, Tensor) -> Tensor
        gi = torch.mm(inputs, self.w_ih.t()) + self.b_ih
        gh = torch.mm(hidden, self.w_hh.t()) + self.b_hh
        i_r, i_i, i_n = gi.chunk(self.ngate, 1)
        h_r, h_i, h_n = gh.chunk(self.ngate, 1)
        resetgate = torch.sigmoid(i_r + h_r)
        inputgate = torch.sigmoid(i_i + h_i)
        newgate = torch.tanh(i_n + resetgate * h_n)
        # equivalent to (1 - inputgate) * newgate + inputgate * hidden
        hy = newgate + inputgate * (hidden - newgate)
        return hy

class GRULayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(GRULayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, inputs, out):
        # type: (Tensor, Tensor) -> Tensor
        # Unroll the cell over the time dimension; `out` carries the hidden state.
        inputs = inputs.unbind(0)
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out = self.cell(inputs[i], out)
            outputs += [out]
        return torch.stack(outputs)
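
A single-step check along the following lines can be used to confirm that the cell by itself matches the native implementation. This is only a sketch of what I mean by testing one cell; it copies the custom cell's parameters into torch.nn.GRUCell (whose parameters are weight_ih, weight_hh, bias_ih, bias_hh) and compares one forward/backward pass:

# Minimal single-step sanity check (sketch): copy the custom cell's parameters
# into torch.nn.GRUCell and compare gradients for one forward/backward pass.
torch.manual_seed(10)
batch, input_size, hidden_size = 5, 3, 2

custom_cell = GRUCell(input_size, hidden_size)
native_cell = nn.GRUCell(input_size, hidden_size)
with torch.no_grad():
    native_cell.weight_ih.copy_(custom_cell.w_ih)
    native_cell.weight_hh.copy_(custom_cell.w_hh)
    native_cell.bias_ih.copy_(custom_cell.b_ih)
    native_cell.bias_hh.copy_(custom_cell.b_hh)

x = torch.randn(batch, input_size)
h0 = torch.randn(batch, hidden_size)
custom_cell(x, h0).sum().backward()
native_cell(x, h0).sum().backward()

# Parameter order is w_ih, w_hh, b_ih, b_hh in both modules.
for p_custom, p_native in zip(custom_cell.parameters(), native_cell.parameters()):
    print((p_custom.grad - p_native.grad).abs().max().item())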

The outputs and gradients are compared:

torch.manual_seed(10)
seq_len = 5
batch = 5
input_size = 3
num_classes = 2
hidden_size = num_classes
criterion = nn.CrossEntropyLoss()
inp = torch.randn(seq_len, batch, input_size)
label = torch.randint(low = 0, high = num_classes, size = (batch,))
state = torch.randn(batch, hidden_size)

# Custom GRU
rnn = GRULayer(GRUCell, input_size, hidden_size)
out = rnn(inp, state)
out = out[-1]  # last time step
loss = criterion(out, label)
loss.backward()
gradients = [x.grad for x in rnn.parameters()]

# Control: PyTorch native GRU
native_gru = nn.GRU(input_size, hidden_size, 1, batch_first = False)
native_gru_state = state.unsqueeze(0)
for native_gru_param, custom_param in zip(native_gru.all_weights[0], rnn.parameters()):
    assert native_gru_param.shape == custom_param.shape
    with torch.no_grad():
        native_gru_param.copy_(custom_param)
native_gru_out, native_gru_out_state = native_gru(inp, native_gru_state)
native_gru_out = native_gru_out[-1, :, :]
native_gru_loss = criterion(native_gru_out, label)
native_gru_loss.backward()
native_gru_gradients = [x.grad for x in native_gru.all_weights[0]]

print("loss is", loss.item())
print("loss difference is", (loss - native_gru_loss).max().item())
print("gradient differences are")
for x, y in zip(gradients, native_gru_gradients):
    # print(x.abs().max().item())
    # print(y.abs().max().item())
    print((x - y).abs().max().item())

Here are the outputs:

loss is 0.5983431935310364
loss difference is 0.0
gradient differences are
0.16684868931770325
0.10169483721256256
0.08745706081390381
0.06843984127044678

I think the problem lies in GRULayer, as I am able to reproduce the gradients for a single GRUCell (e.g. by setting seq_len = 1). What makes the problem even stranger is that I am able to reproduce the gradients of the PyTorch native LSTM layers with the original custom LSTM provided in the link.
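
One way I can think of to narrow down where the layer-level backward diverges (a sketch, reusing rnn, native_gru, inp, state, label and criterion from the script above) is to compare the gradients with respect to the input sequence per time step, since the parameter gradients are summed over all steps:

# Sketch: compare d(loss)/d(input) per time step between the two models.
inp.requires_grad_(True)

custom_loss = criterion(rnn(inp, state)[-1], label)
custom_inp_grad, = torch.autograd.grad(custom_loss, inp)

native_out, _ = native_gru(inp, native_gru_state)
native_loss = criterion(native_out[-1], label)
native_inp_grad, = torch.autograd.grad(native_loss, inp)

for t in range(seq_len):
    print(t, (custom_inp_grad[t] - native_inp_grad[t]).abs().max().item())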

Your help is very much appreciated!

Did you solve this problem?