Hi there, I am trying to understand the math behind a bidirectional RNN. I understand how the forward hidden state is computed, but I am having trouble understanding exactly how the backward/reverse hidden state is calculated.
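For reference, this is the per-timestep update I am assuming for each direction; it matches the formula in the nn.RNN documentation, with the reverse direction having its own set of weights:

h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1} W_{hh}^T + b_{hh})

What I am not sure about is the order in which the reverse direction consumes the inputs when applying this recurrence.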
I first made a vanilla RNN using PyTorch, then I tried to reproduce it manually, as shown in the code below:
import torch
import torch.nn as nn

torch.manual_seed(1)
n_in, n_out = 3, 5
inpt = torch.randn([6])
inpt = inpt.view(2, 1, 3)
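# inpt shape: (seq_len=2, batch=1, input_size=3), nn.RNN's default layout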
# print(f'inpt: {inpt}, inpt shape: {inpt.shape}')
vanilla_rnn = nn.RNN(n_in, n_out, bidirectional=True)
out, hx = vanilla_rnn(inpt)
print(f'out: {out}, \n\nout shape: {out.shape}, \n\nhx: {hx}, \n\nhx shape: {hx.shape}')
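# out has shape (seq_len, batch, 2 * hidden_size); hx has shape (2, batch, hidden_size)
# Forward-direction parameters of layer 0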
wih = vanilla_rnn.weight_ih_l0
whh = vanilla_rnn.weight_hh_l0
bih = vanilla_rnn.bias_ih_l0
bhh = vanilla_rnn.bias_hh_l0
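# Reverse-direction parameters of layer 0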
wihr = vanilla_rnn.weight_ih_l0_reverse
whhr = vanilla_rnn.weight_hh_l0_reverse
bihr = vanilla_rnn.bias_ih_l0_reverse
bhhr = vanilla_rnn.bias_hh_l0_reverse
with torch.no_grad():
    for i in range(hx.shape[0]):
        print(hx[i])
    print("--------------------------------")
    print("--------------------------------")
    tanh = nn.Tanh()
    # hid = torch.zeros(2, 1, 5, dtype = torch.float32)
    hid_forward = torch.zeros(1, 5, dtype=torch.float32)
    hid_reverse = torch.zeros(1, 5, dtype=torch.float32)
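    # Step through the input sequence one timestep at a time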
    for i in range(inpt.shape[0]):
        x = inpt[i]
        # Forward
        i_forward = x @ torch.transpose(wih, 0, 1) + bih
        h_forward = hid_forward @ torch.transpose(whh, 0, 1) + bhh
        hid_forward = tanh(i_forward + h_forward)
        # Reverse
        i_reverse = x @ torch.transpose(wihr, 0, 1) + bihr
        h_reverse = torch.flip(h_forward, (0, 1)) @ torch.transpose(whhr, 0, 1) + bhhr
        hid_reverse = tanh(i_reverse + h_reverse)
    print(hid_forward)
    print(hid_reverse)
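To compare the results programmatically rather than by eye, I also check them with torch.allclose (the tolerance here is just my choice):

print(torch.allclose(hx[0], hid_forward, atol=1e-6))  # True -- forward matches
print(torch.allclose(hx[1], hid_reverse, atol=1e-6))  # False -- reverse does not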
You will see that the first tensor of hx matches 'hid_forward', but the second tensor of hx does not match 'hid_reverse'. What am I doing wrong here?