How does PyTorch compute a bidirectional RNN?

Hi there, I am trying to understand the math behind a bidirectional RNN. I understand how the forward hidden state is computed, but I am having trouble understanding exactly how the backward (reverse) hidden state is calculated.

I first built a vanilla bidirectional RNN with PyTorch, and then tried to reproduce its computation manually, as shown in the code below:

import torch
from torch import nn

torch.manual_seed(1)
n_in, n_out = 3, 5

# Reshape 6 random values into (seq_len=2, batch=1, n_in=3). nn.RNN defaults
# to batch_first=False, so dim 0 is time.
inpt = torch.randn([6])
inpt = inpt.view(2, 1, 3)

vanilla_rnn = nn.RNN(n_in, n_out, bidirectional=True)
out, hx = vanilla_rnn(inpt)

print(f'out: {out}, \n\nout shape: {out.shape}, \n\nhx: {hx}, \n\nhx shape: {hx.shape}')

# Layer-0 parameters of the forward direction.
wih = vanilla_rnn.weight_ih_l0
whh = vanilla_rnn.weight_hh_l0
bih = vanilla_rnn.bias_ih_l0
bhh = vanilla_rnn.bias_hh_l0

# Layer-0 parameters of the reverse direction: a separate set of weights.
wihr = vanilla_rnn.weight_ih_l0_reverse
whhr = vanilla_rnn.weight_hh_l0_reverse
bihr = vanilla_rnn.bias_ih_l0_reverse
bhhr = vanilla_rnn.bias_hh_l0_reverse

# One final hidden state per direction: hx[0] forward, hx[1] reverse.
with torch.no_grad():
    for i in range(hx.shape[0]):
        print(hx[i])
print("--------------------------------")
print("--------------------------------")

tanh = nn.Tanh()
hid_forward = torch.zeros(1, 5, dtype=torch.float32)
hid_reverse = torch.zeros(1, 5, dtype=torch.float32)

# Walk the time steps of `inpt`. (Referencing the builtin `input` here would
# fail: it is a function with no `.shape`.)
for i in range(inpt.shape[0]):
    x = inpt[i]

    # Forward: h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh)
    i_forward = x @ torch.transpose(wih, 0, 1) + bih
    h_forward = hid_forward @ torch.transpose(whh, 0, 1) + bhh
    hid_forward = tanh(i_forward + h_forward)

    # Reverse -- THIS IS THE MISTAKE the question asks about. It reuses the
    # forward pre-activation (with its feature order flipped) instead of
    # running the reverse direction's own recurrence on `hid_reverse`, and it
    # visits the inputs in forward rather than reversed time order.
    i_reverse = x @ torch.transpose(wihr, 0, 1) + bihr
    h_reverse = torch.flip(h_forward, (0, 1)) @ torch.transpose(whhr, 0, 1) + bhhr

    hid_reverse = tanh(i_reverse + h_reverse)

print(hid_forward)
print(hid_reverse)

You will see that the first tensor of `hx` matches `hid_forward`, but the second tensor of `hx` does not match `hid_reverse`. What am I doing wrong here?

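OK, I think I've figured it out. The reverse direction does not reuse the forward hidden state at all. It runs its own recurrence, with its own weights and a zero initial state, over the input sequence in reverse time order. So the inputs themselves need to be reversed before they are passed through the reverse direction. Here is the corrected code: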

# One final hidden state per direction: hx[0] forward, hx[1] reverse.
with torch.no_grad():
    for i in range(hx.shape[0]):
        print(hx[i])
print("--------------------------------")
print("--------------------------------")

tanh = nn.Tanh()
# Hidden state of each direction at every time step, used to rebuild `out`.
forward_steps = []
reverse_steps = []
batch_size = inpt.shape[1]
hid_forward = torch.zeros(batch_size, n_out, dtype=torch.float32)
hid_reverse = torch.zeros(batch_size, n_out, dtype=torch.float32)

# Forward direction: t = 0 .. T-1. Iterating a tensor walks dim 0 (time).
for x in inpt:
    i_forward = x @ torch.transpose(wih, 0, 1) + bih
    h_forward = hid_forward @ torch.transpose(whh, 0, 1) + bhh
    hid_forward = tanh(i_forward + h_forward)
    forward_steps.append(hid_forward)

# Reverse direction: its own recurrence (own weights, own zero initial state)
# over t = T-1 .. 0. Flip ONLY dim 0 (time). Also flipping dim 1 would
# reverse the batch order, which is harmless only when batch size is 1.
reverse_input = torch.flip(inpt, (0,))

for x in reverse_input:
    i_reverse = x @ torch.transpose(wihr, 0, 1) + bihr
    h_reverse = hid_reverse @ torch.transpose(whhr, 0, 1) + bhhr
    hid_reverse = tanh(i_reverse + h_reverse)
    reverse_steps.append(hid_reverse)

# out[t] concatenates [forward state at t, reverse state at t] on the feature
# axis. reverse_steps was collected in reversed time order, so undo that first.
output = torch.cat(
    [torch.stack(forward_steps), torch.stack(reverse_steps[::-1])], dim=-1
)

print(hid_forward)
print(hid_reverse)
print("-----------")
print(output)