Cannot recreate backpropagation of `torch.nn.RNN`

Hello.

I’m currently trying to understand Backpropagation Through Time (BPTT), and I’m testing my understanding by re-creating the backpropagation calculation that torch.nn.RNN performs.

After reading many articles and working through the mathematical proofs, I am 90% certain that both my theoretical understanding and my code implementation are correct – however, the gradients that I calculate manually do not match the gradients produced by torch, and I have no idea why.

I’ll show you what I mean. Below, torch_rnn is a single, simple RNN layer:

import torch.nn as nn

torch_rnn = nn.RNN(input_size=1, hidden_size=2, batch_first=True)
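
For reference, these are the shapes that nn.RNN gives its parameters (and which the manual calculation below relies on):

print(torch_rnn.weight_ih_l0.shape)  # torch.Size([2, 1]) -> (hidden_size, input_size)
print(torch_rnn.weight_hh_l0.shape)  # torch.Size([2, 2]) -> (hidden_size, hidden_size)
print(torch_rnn.bias_ih_l0.shape)    # torch.Size([2])
print(torch_rnn.bias_hh_l0.shape)    # torch.Size([2])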

I’ll also be using the following sequence as the input to our RNN:

import torch

x = torch.tensor([[1.0],
                  [2.0],
                  [3.0]])

I’ll now perform a forward pass using torch_rnn and store the final hidden state (which, for this single-layer RNN, is also the layer’s output at the last timestep) as torch_last_hidden:

_, torch_last_hidden = torch_rnn(x)
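
As a quick sanity check on the shapes involved (assuming a PyTorch version that accepts unbatched inputs, so x of shape (3, 1) is read as (seq_len, input_size) and batch_first has no effect here):

print(x.shape)                  # torch.Size([3, 1]) -> (seq_len, input_size)
print(torch_last_hidden.shape)  # torch.Size([1, 2]) -> (num_layers, hidden_size)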

Now, I’ll use my own hard-coded implementation to do the same thing, but will copy the weights and biases directly from the torch_rnn object. I’ll then walk forward through the RNN calculations step by step, until I reach h_3, the final hidden state. I’ll also initialise h_0 to all zeros, as is the default behaviour for nn.RNN:

W_ih = torch_rnn.weight_ih_l0.detach()
b_ih = torch_rnn.bias_ih_l0.detach()
W_hh = torch_rnn.weight_hh_l0.detach()
b_hh = torch_rnn.bias_hh_l0.detach()

h_0 = torch.zeros(1, 2)  # (1, hidden_size); hidden_size = 2 here

tanh = torch.nn.Tanh()

z_1 = (x[0] @ W_ih.T) + b_ih + (h_0 @ W_hh.T) + b_hh
h_1 = tanh(z_1)

z_2 = (x[1] @ W_ih.T) + b_ih + (h_1 @ W_hh.T) + b_hh
h_2 = tanh(z_2)

z_3 = (x[2] @ W_ih.T) + b_ih + (h_2 @ W_hh.T) + b_hh
h_3 = tanh(z_3)

Let’s make sure that the manual calculation closely matches the torch implementation:

assert torch.allclose(h_3, torch_last_hidden, atol=1e-6)

^^ This assert statement doesn’t throw an error, so the forward pass has been successfully recreated.

Next is the backward pass. For this example, I will try to calculate the gradient of h_3 (or, more precisely, of h_3.sum()) with respect to W_ih. The theory behind how to do this (as I understand it) is to sum the chain-rule contributions from every timestep, i.e. to calculate the following:
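
dh_3/dW_ih =  (dh_3/dz_3)(dz_3/dW_ih)
            + (dh_3/dz_3)(dz_3/dh_2)(dh_2/dz_2)(dz_2/dW_ih)
            + (dh_3/dz_3)(dz_3/dh_2)(dh_2/dz_2)(dz_2/dh_1)(dh_1/dz_1)(dz_1/dW_ih)

where (writing each factor loosely, and leaving the transpose/broadcast bookkeeping to the code) dh_t/dz_t = 1 - tanh(z_t)^2, dz_t/dh_(t-1) = W_hh, and dz_t/dW_ih brings in the input at step t.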

I’ll now try applying this to the problem above and manually perform the backpropagation for W_ih. I’ll also let torch perform its own calculation of the W_ih gradients for comparison, and save the result to torch_W_ih_gradients:

import copy
import torch.optim as optim

optimiser = optim.SGD(torch_rnn.parameters(), lr=0.001)

optimiser.zero_grad()

torch_last_hidden.sum().backward()
torch_W_ih_gradients = copy.deepcopy(torch_rnn.weight_ih_l0.grad)

# Contribution from timestep 3 (the direct path into h_3)
dh_3_dW_ih_3 = (1 - tanh(z_3)**2)*x[2]

# Contribution from timestep 2 (flowing through h_2 into h_3)
dh_3_dW_ih_2 = (((1 - tanh(z_3)**2) @ W_hh) * 
                ((1 - tanh(z_2)**2) * x[1]))

# Contribution from timestep 1 (flowing through h_1 and h_2 into h_3)
dh_3_dW_ih_1 = (((1 - tanh(z_3)**2) @ W_hh) * 
                ((1 - tanh(z_2)**2) @ W_hh) * 
                ((1 - tanh(z_1)**2) * x[0]))

# Total gradient: sum of the per-timestep contributions
dh_3_dW_ih = dh_3_dW_ih_1 + dh_3_dW_ih_2 + dh_3_dW_ih_3

assert torch.allclose(dh_3_dW_ih.T, torch_W_ih_gradients, atol=1e-6)

The above code will throw an AssertionError, and a manual inspection of the gradients reveals that they are in fact different:

dh_3_dW_ih.T -------------> tensor([[0.5162],
                                    [0.5511]])
                                    
torch_W_ih_gradients -----> tensor([[0.5053],
                                    [0.5356]])

The difference is small, but significant enough to cause concern.
Have I made a mistake in my theoretical understanding of backpropagation through the RNN?
Is there a mistake with my code implementation?
Why are my calculations not yielding the same gradients as the torch implementation?