Hidden sizes in hidden layers of Bidirectional RNN

Hello everybody!
I have been reading about how bidirectional RNNs are computed (bidirectional outputs and intermediate inputs), and @vdw mentioned: "I think most people just concatenate them before the next step, e.g., pushing through a Linear(2*hidden_size, output_size)". However, when I tried to implement this idea, the forward hidden states of the 2nd layer did not match those of the built-in bidirectional RNN.

Here’s my code:

First, I define rnn_bidir as a standard PyTorch bidirectional RNN, and rnn_forward_{order} / rnn_backward_{order} as ordinary unidirectional RNNs, so that I can stack their outputs for later investigation.

import torch
rnn_bidir = torch.nn.RNN(25, 50, 4, batch_first=True, bidirectional=True)
rnn_forward_1st = torch.nn.RNN(25, 50, 1, batch_first=True, bidirectional=False)   # 1st hidden layer: weight_ih is (50, 25)
rnn_forward_2nd = torch.nn.RNN(100, 50, 1, batch_first=True, bidirectional=False)  # 2nd hidden layer: weight_ih is (50, 100); 100 = concatenated forward and backward hidden states
rnn_forward_3rd = torch.nn.RNN(100, 50, 1, batch_first=True, bidirectional=False)  # 3rd hidden layer: weight_ih is (50, 100)
rnn_forward_4th = torch.nn.RNN(100, 50, 1, batch_first=True, bidirectional=False)  # 4th hidden layer: weight_ih is (50, 100)

rnn_backward_1st = torch.nn.RNN(25, 50, 1, batch_first=True, bidirectional=False)  # Same as the forward layers
rnn_backward_2nd = torch.nn.RNN(100, 50, 1, batch_first=True, bidirectional=False)
rnn_backward_3rd = torch.nn.RNN(100, 50, 1, batch_first=True, bidirectional=False)
rnn_backward_4th = torch.nn.RNN(100, 50, 1, batch_first=True, bidirectional=False)

After that, I assign the weights and biases of rnn_bidir to the individual RNN layers:

rnn_forward_1st.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l0)
rnn_forward_1st.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l0)
rnn_forward_1st.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l0)
rnn_forward_1st.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l0)

rnn_forward_2nd.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l1)
rnn_forward_2nd.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l1)
rnn_forward_2nd.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l1)
rnn_forward_2nd.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l1)

rnn_forward_3rd.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l2)
rnn_forward_3rd.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l2)
rnn_forward_3rd.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l2)
rnn_forward_3rd.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l2)

rnn_forward_4th.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l3)
rnn_forward_4th.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l3)
rnn_forward_4th.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l3)
rnn_forward_4th.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l3)

rnn_backward_1st.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l0_reverse)
rnn_backward_1st.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l0_reverse)
rnn_backward_1st.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l0_reverse)
rnn_backward_1st.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l0_reverse)

rnn_backward_2nd.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l1_reverse)
rnn_backward_2nd.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l1_reverse)
rnn_backward_2nd.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l1_reverse)
rnn_backward_2nd.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l1_reverse)

rnn_backward_3rd.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l2_reverse)
rnn_backward_3rd.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l2_reverse)
rnn_backward_3rd.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l2_reverse)
rnn_backward_3rd.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l2_reverse)

rnn_backward_4th.weight_ih_l0 = torch.nn.Parameter(rnn_bidir.weight_ih_l3_reverse)
rnn_backward_4th.weight_hh_l0 = torch.nn.Parameter(rnn_bidir.weight_hh_l3_reverse)
rnn_backward_4th.bias_ih_l0 = torch.nn.Parameter(rnn_bidir.bias_ih_l3_reverse)
rnn_backward_4th.bias_hh_l0 = torch.nn.Parameter(rnn_bidir.bias_hh_l3_reverse)
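(For reference, the same copy can be written more compactly with a loop over the layer index; this is only a sketch, and the list names forward_rnns / backward_rnns are ones I introduce here, not part of the original code:)

# Hypothetical compact version of the parameter copy above; forward_rnns /
# backward_rnns hold the four single-layer RNNs in layer order.
forward_rnns = [rnn_forward_1st, rnn_forward_2nd, rnn_forward_3rd, rnn_forward_4th]
backward_rnns = [rnn_backward_1st, rnn_backward_2nd, rnn_backward_3rd, rnn_backward_4th]

for layer, (fwd, bwd) in enumerate(zip(forward_rnns, backward_rnns)):
    for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh"):
        # e.g. rnn_bidir.weight_ih_l2 -> weight_ih_l0 of the 3rd forward RNN
        setattr(fwd, f"{name}_l0",
                torch.nn.Parameter(getattr(rnn_bidir, f"{name}_l{layer}")))
        # e.g. rnn_bidir.weight_ih_l2_reverse -> weight_ih_l0 of the 3rd backward RNN
        setattr(bwd, f"{name}_l0",
                torch.nn.Parameter(getattr(rnn_bidir, f"{name}_l{layer}_reverse")))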

Then, I create initial hidden states for both the forward and backward directions, and interleave them (forward, backward, forward, backward, ...) to feed into rnn_bidir.

# Initial hidden states: (num_layers, batch, hidden_size) per direction
hidden_forward = torch.randn(4, 15, 50)
hidden_backward = torch.randn(4, 15, 50)
# Interleave them as [fwd_l0, bwd_l0, fwd_l1, bwd_l1, ...] -> (8, 15, 50)
hidden_bidir_init = torch.cat((hidden_forward[:1], hidden_backward[:1]), 0)
for i, (hid_for, hid_back) in enumerate(zip(hidden_forward[1:], hidden_backward[1:])):
    hidden_bidir_init = torch.cat((hidden_bidir_init, torch.cat((hid_for.unsqueeze(0), hid_back.unsqueeze(0)), 0)), 0)
input = torch.randn(15, 5, 25)  # (batch, seq_len, input_size) since batch_first=True
out_bidir, hidden_bidir = rnn_bidir(input, hidden_bidir_init)
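(For a bidirectional RNN, the initial hidden state has shape (num_layers * 2, batch, hidden_size) with the two directions of each layer interleaved, which is what the loop above builds. A more direct way to get the same layout, just as a sketch with a name I made up, would be:)

# (4, 15, 50) + (4, 15, 50) -> (4, 2, 15, 50) -> (8, 15, 50),
# interleaved as [fwd_l0, bwd_l0, fwd_l1, bwd_l1, ...]
hidden_bidir_init_alt = torch.stack((hidden_forward, hidden_backward), dim=1).reshape(8, 15, 50)
assert torch.equal(hidden_bidir_init, hidden_bidir_init_alt)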

After that, I first test the 1st and 2nd layers that I defined:

# Reverse the input along the time dimension (dim 1) for the backward pass
inverse_input = input[:, -1:]
for i, _ in enumerate(input[:, :-1]):
    inverse_input = torch.cat((inverse_input, input[:, -i-2:-i-1]), 1)

out_forward_1st, hid_forward_1st = rnn_forward_1st(input, hidden_forward[:1])
out_backward_1st, hid_backward_1st = rnn_backward_1st(inverse_input, hidden_backward[:1])

# 2nd layer: feed the concatenated forward/backward outputs of the 1st layer
out_forward_2nd, hid_forward_2nd = rnn_forward_2nd(torch.cat((out_forward_1st, out_backward_1st), -1), hidden_forward[1:2])
out_backward_2nd, hid_backward_2nd = rnn_backward_2nd(torch.cat((out_forward_1st, out_backward_1st), -1), hidden_backward[1:2])

For the 1st forward and backward layers, this works: hidden_bidir[:1] - hid_forward_1st is all zeros, and the same holds for hid_backward_1st (against hidden_bidir[1:2]). However, when I step into the 2nd layer, the difference is no longer zero.
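(For completeness, the check looks like this; exact zeros are not guaranteed in floating point, so torch.allclose is the safer comparison. The variables are the ones from the blocks above, and the True/False comments reflect what I observe:)

# Layer 1 matches: forward state is row 0, backward state is row 1 of hidden_bidir
print(torch.allclose(hidden_bidir[0:1], hid_forward_1st))   # True
print(torch.allclose(hidden_bidir[1:2], hid_backward_1st))  # True
# Layer 2 forward does not match yet
print(torch.allclose(hidden_bidir[2:3], hid_forward_2nd))   # False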

Questions:

  1. If @vdw’s idea is right, why doesn't concatenating the forward and backward states of the previous layer reproduce the results of PyTorch’s bidirectional RNN? If it isn't, could you please tell me the computation flow of the bidirectional RNN in PyTorch?
  2. Does the backward direction also have to take the concatenation with the forward outputs as input? If not, why is the shape of the input-to-hidden weights in the upper layers of a bidirectional RNN (hidden_size, hidden_size*2) for both the forward and backward directions?
  3. Aside from the computation flow, how does PyTorch handle the initial hidden state (D*num_layers, N, H_out) for the backward direction? Is it the initial state at the last timestep of the input sequence (the start of the backward direction), like in the picture below?

Thank you in advance!!

  1. I’m actually not sure how the outputs of the forward and backward pass are handled between layers in the case of a bi-RNN with multiple layers. This might require looking into the code.

2+3) My understanding is that the forward and backward passes are independent, and both start from the same initial state (e.g., all zeros).
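(As a small sketch of the "all zeros" case, using the rnn_bidir and input defined in the first post: if no h_0 is passed, PyTorch defaults it to zeros for every layer and both directions.)

# Omitting h_0 is equivalent to passing an all-zeros state of shape (8, 15, 50)
out_default, hidden_default = rnn_bidir(input)
out_zeros, hidden_zeros = rnn_bidir(input, torch.zeros(8, 15, 50))
print(torch.allclose(out_default, out_zeros))  # True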


Thank you, Chris.
I tried to implement this idea:

I think most people just concatenate them before the next step, e.g., pushing through a Linear(2*hidden_size, output_size).

in my last code block above, by concatenating the outputs of the first layer from both the forward and backward passes as the input to the next layer (layer 2):

out_forward_2nd, hid_forward_2nd = rnn_forward_2nd(torch.cat((out_forward_1st, out_backward_1st), -1), hidden_forward[1:2])
out_backward_2nd, hid_backward_2nd = rnn_backward_2nd(torch.cat((out_forward_1st, out_backward_1st), -1), hidden_backward[1:2])

since, when I inspect the weights and biases of the PyTorch bidirectional RNN (rnn_bidir), the input size of the upper layers is doubled for both the forward and backward directions (the same goes for layers 3 and 4):
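(The attached screenshot isn't shown here, but the shapes it refers to can be checked directly; with input_size=25 and hidden_size=50 they come out as:)

print(rnn_bidir.weight_ih_l0.shape)          # torch.Size([50, 25])
print(rnn_bidir.weight_ih_l1.shape)          # torch.Size([50, 100])
print(rnn_bidir.weight_ih_l1_reverse.shape)  # torch.Size([50, 100])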

But when I subtract the hidden state of the PyTorch bidirectional RNN from that of my own layers, hidden_bidir[2:3] - hid_forward_2nd is not zeros. Doing the same for the first hidden layer works, which shows that the weights I assigned to each layer I defined are correct.

Could you please tell me where I got your idea wrong? Thank you!

Hi there,
I’ve figured out the solution after some trial and error with @vdw's idea! The input to each layer (above the first) is, just like the final output, the concatenation of the forward and backward outputs of the previous layer (forward first, then backward), BUT the output of the opposite direction has to be reversed in time first. In other words, you have to reverse the opposite direction's output before concatenating it for the current direction.
P.S.: Thank you so much for your intuition @vdw

# Reverse the input along the time dimension (dim 1) for the backward pass
inverse_input = input[:, -1:]
for i, _ in enumerate(input[:, :-1]):
    inverse_input = torch.cat((inverse_input, input[:, -i-2:-i-1]), 1)

out_forward_1st, hid_forward_1st = rnn_forward_1st(input, hidden_forward[:1])
out_backward_1st, hid_backward_1st = rnn_backward_1st(inverse_input, hidden_backward[:1])

### This is the solution: reverse each direction's 1st-layer output in time
### before feeding it to the opposite direction of the next layer
inverse_out_backward_1st = out_backward_1st[:, -1:]
for i, _ in enumerate(input[:, :-1]):
    inverse_out_backward_1st = torch.cat((inverse_out_backward_1st, out_backward_1st[:, -i-2:-i-1]), 1)

inverse_out_forward_1st = out_forward_1st[:, -1:]
for i, _ in enumerate(input[:, :-1]):
    inverse_out_forward_1st = torch.cat((inverse_out_forward_1st, out_forward_1st[:, -i-2:-i-1]), 1)
### End of solution

out_forward_2nd, hid_forward_2nd = rnn_forward_2nd(torch.cat((out_forward_1st, inverse_out_backward_1st), -1), hidden_forward[1:2])
out_backward_2nd, hid_backward_2nd = rnn_backward_2nd(torch.cat((inverse_out_forward_1st, out_backward_1st), -1), hidden_backward[1:2])
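(For anyone landing here later, the same idea can be written more compactly with torch.flip and checked against the built-in module. This is just a sketch using the modules and tensors defined earlier in the thread; the short variable names and the rev helper are mine, and the True comments reflect what the comparison gave me:)

def rev(t):
    # Flip along the time dimension (dim 1, since batch_first=True)
    return torch.flip(t, dims=[1])

out_f1, hid_f1 = rnn_forward_1st(input, hidden_forward[:1])
out_b1, hid_b1 = rnn_backward_1st(rev(input), hidden_backward[:1])

# Layer 2 input: concatenate the two directions in the SAME time order as the
# direction being computed, i.e. flip the opposite direction's output first.
out_f2, hid_f2 = rnn_forward_2nd(torch.cat((out_f1, rev(out_b1)), -1), hidden_forward[1:2])
out_b2, hid_b2 = rnn_backward_2nd(torch.cat((rev(out_f1), out_b1), -1), hidden_backward[1:2])

print(torch.allclose(hidden_bidir[2:3], hid_f2, atol=1e-6))  # True
print(torch.allclose(hidden_bidir[3:4], hid_b2, atol=1e-6))  # True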