BiLSTM : Output & Hidden State Mismatch

Have a look at @vdw’s great visualization and explanation of the hidden state and output of an RNN: post.

While he explains the usage for bidirectional=False, the docs describe how to separate the two directions for bidirectional=True, so we can compare the output to the hidden state directly:

import torch
import torch.nn as nn

rnn = nn.LSTM(5, 8, 1, bidirectional=True)  # input_size, hidden_size, num_layers
h0 = torch.zeros(2*1, 1, 8)
c0 = torch.zeros(2*1, 1, 8)

x = torch.randn(6, 1, 5)
output, (h_n, c_n) = rnn(x, (h0, c0))

# Separate directions
output = output.view(6, 1, 2, 8) #seq_len, batch, num_directions, hidden_size
h_n = h_n.view(1, 2, 1, 8) # num_layers, num_directions, batch, hidden_size

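# Compare directions (each comparison is elementwise and should be all True)
output[-1, :, 0] == h_n[:, 0]  # forward: final hidden state sits at the last time step
output[0, :, 1] == h_n[:, 1]  # backward: final hidden state sits at the first time step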
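
The same check can be sketched for a stacked LSTM with a larger batch. This is only an illustration under a few assumptions of mine: num_layers=2, batch=3, default zero initial states, and torch.allclose in place of the elementwise ==. Note that output only contains the top layer, so it is compared against the last layer's slice of h_n.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Sizes chosen for illustration only (not from the original post)
seq_len, batch, input_size, hidden_size, num_layers = 6, 3, 5, 8, 2

rnn = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

with torch.no_grad():
    output, (h_n, c_n) = rnn(x)  # initial states default to zeros

# output: (seq_len, batch, 2 * hidden_size), top layer only
out_dirs = output.view(seq_len, batch, 2, hidden_size)
# h_n: (num_layers * 2, batch, hidden_size), all layers
h_dirs = h_n.view(num_layers, 2, batch, hidden_size)

# Forward direction finishes at the last time step,
# backward direction finishes at the first time step.
forward_match = torch.allclose(out_dirs[-1, :, 0], h_dirs[-1, 0])
backward_match = torch.allclose(out_dirs[0, :, 1], h_dirs[-1, 1])

# Concatenating both final states of the top layer gives one vector per sequence
summary = torch.cat((h_dirs[-1, 0], h_dirs[-1, 1]), dim=1)

print(forward_match, backward_match, summary.shape)
```

So the apparent mismatch comes from comparing output[-1] as a whole with h_n: only the forward half of output[-1] is a final state, while the backward half of output[-1] has seen just one input.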