# How does PyTorch compute a bidirectional RNN?

Hi there, I am trying to understand the math behind a bidirectional RNN. I understand how the forward hidden state is computed, but I am having trouble understanding exactly how the backward (reverse) hidden state is calculated.

I first built a vanilla RNN with PyTorch, then tried to reproduce the computation manually, as shown in the code:

``````
import torch
import torch.nn as nn

torch.manual_seed(1)
n_in, n_out = 3, 5

inpt = torch.randn([6])

inpt = inpt.view(2, 1, 3)

# print(f'inpt: {inpt}, inpt shape: {inpt.shape}')

vanilla_rnn = nn.RNN(n_in, n_out,bidirectional = True)
out, hx = vanilla_rnn(inpt)

print(f'out: {out}, \n\nout shape: {out.shape}, \n\nhx: {hx}, \n\nhx shape: {hx.shape}')

wih = vanilla_rnn.weight_ih_l0
whh = vanilla_rnn.weight_hh_l0
bih = vanilla_rnn.bias_ih_l0
bhh = vanilla_rnn.bias_hh_l0

wihr = vanilla_rnn.weight_ih_l0_reverse
whhr = vanilla_rnn.weight_hh_l0_reverse
bihr = vanilla_rnn.bias_ih_l0_reverse
bhhr = vanilla_rnn.bias_hh_l0_reverse

with torch.no_grad():
    for i in range(hx.shape[0]):
        print(hx[i])
        print("--------------------------------")
    print("--------------------------------")

    tanh = nn.Tanh()
    # hid = torch.zeros(2, 1, 5, dtype = torch.float32)
    hid_forward = torch.zeros(1, 5, dtype = torch.float32)
    hid_reverse = torch.zeros(1, 5, dtype = torch.float32)

    for i in range(inpt.shape[0]):
        x = inpt[i]

        # Forward
        i_forward = x @ torch.transpose(wih, 0, 1) + bih
        h_forward = hid_forward @ torch.transpose(whh, 0, 1) + bhh
        hid_forward = tanh(i_forward + h_forward)

        # Reverse (this is the attempt that does not match hx[1])
        i_reverse = x @ torch.transpose(wihr, 0, 1) + bihr
        h_reverse = torch.flip(h_forward, (0,1)) @ torch.transpose(whhr, 0, 1) + bhhr

        hid_reverse = tanh(i_reverse + h_reverse)

    print(hid_forward)
    print(hid_reverse)
``````

You will see that the first tensor of hx matches ‘hid_forward’, but the second tensor of hx does not match ‘hid_reverse’. What am I doing wrong here?

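OK, so I think I got it. It looks like the inputs themselves need to be reversed when they are passed through the backward (reverse) direction, and the reverse direction keeps its own hidden state instead of reusing the forward one. Here is the code: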

``````
with torch.no_grad():
    for i in range(hx.shape[0]):
        print(hx[i])
        print("--------------------------------")
    print("--------------------------------")

    tanh = nn.Tanh()
    output = []
    hid_forward = torch.zeros(1, 5, dtype = torch.float32)
    hid_reverse = torch.zeros(1, 5, dtype = torch.float32)

    # Forward
    for i in range(inpt.shape[0]):
        x = inpt[i]

        i_forward = x @ torch.transpose(wih, 0, 1) + bih
        h_forward = hid_forward @ torch.transpose(whh, 0, 1) + bhh
        hid_forward = tanh(i_forward + h_forward)

    # Reverse: flip only the time dimension (dim 0).
    # Flipping dim 1 as well would reorder the batch once batch size > 1.
    reverse_input = torch.flip(inpt, (0,))

    for i in range(reverse_input.shape[0]):
        x = reverse_input[i]

        i_reverse = x @ torch.transpose(wihr, 0, 1) + bihr
        h_reverse = hid_reverse @ torch.transpose(whhr, 0, 1) + bhhr
        hid_reverse = tanh(i_reverse + h_reverse)

    print(hid_forward)
    print(hid_reverse)
    print("-----------")
``````
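
To check this beyond eyeballing the printed tensors, here is a minimal sketch that rebuilds both `hx` and the full `out` tensor by hand and compares them with `torch.allclose`. The helper `run_direction` and the sizes `seq_len = 4`, `batch = 2` are my own illustrative choices, not part of the code above. It assumes the default `nn.RNN` settings (one layer, `tanh` nonlinearity, `batch_first=False`, zero initial hidden state).

```python
import torch
import torch.nn as nn

torch.manual_seed(1)
n_in, n_out = 3, 5
seq_len, batch = 4, 2  # illustrative sizes, not the ones used in the post

rnn = nn.RNN(n_in, n_out, bidirectional=True)
x = torch.randn(seq_len, batch, n_in)  # layout: (seq_len, batch, n_in)


def run_direction(seq, w_ih, w_hh, b_ih, b_hh):
    """Hypothetical helper: unroll a tanh RNN over seq and return every hidden state."""
    h = torch.zeros(seq.shape[1], w_hh.shape[0])  # zero initial state, (batch, n_out)
    states = []
    for x_t in seq:  # iterate over time steps
        h = torch.tanh(x_t @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
        states.append(h)
    return torch.stack(states)  # (seq_len, batch, n_out)


with torch.no_grad():
    out, hx = rnn(x)

    fwd = run_direction(x, rnn.weight_ih_l0, rnn.weight_hh_l0,
                        rnn.bias_ih_l0, rnn.bias_hh_l0)

    # Reverse direction: feed the sequence back to front (flip time only),
    # then flip the states back so that index t lines up with input step t.
    bwd = run_direction(torch.flip(x, (0,)),
                        rnn.weight_ih_l0_reverse, rnn.weight_hh_l0_reverse,
                        rnn.bias_ih_l0_reverse, rnn.bias_hh_l0_reverse)
    bwd = torch.flip(bwd, (0,))

    # out concatenates forward and reverse states along the feature dimension.
    manual_out = torch.cat([fwd, bwd], dim=-1)
    # hx[0] is the forward state after the last step; hx[1] is the reverse
    # state after it has consumed the whole sequence, which sits at time step 0.
    manual_hx = torch.stack([fwd[-1], bwd[0]])

out_matches = torch.allclose(out, manual_out, atol=1e-5)
hx_matches = torch.allclose(hx, manual_hx, atol=1e-5)
print(out_matches, hx_matches)
```

One thing this makes visible: the reverse half of `out` at the first time step, `out[0, :, n_out:]`, is the same tensor as `hx[1]`, while the forward half at the last time step, `out[-1, :, :n_out]`, is `hx[0]`.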