Have a look at @vdw’s great visualization and explanation of the hidden state and output of an RNN: post.
While he explains the usage for bidirectional=False, for a bidirectional RNN we can use the shape conventions from the docs to separate the two directions and compare the output to the hidden state:
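import torch
import torch.nn as nn

rnn = nn.LSTM(5, 8, 1, bidirectional=True)  # input_size=5, hidden_size=8, num_layers=1
h0 = torch.zeros(2*1, 1, 8)  # num_layers * num_directions, batch, hidden_size
c0 = torch.zeros(2*1, 1, 8)
x = torch.randn(6, 1, 5)  # seq_len, batch, input_size
output, (h_n, c_n) = rnn(x, (h0, c0))
# Separate directions
output = output.view(6, 1, 2, 8)  # seq_len, batch, num_directions, hidden_size
h_n = h_n.view(1, 2, 1, 8)  # num_layers, num_directions, batch, hidden_size
# Compare directions; h_n[-1] selects the last layer, so both sides have shape (batch, hidden_size)
print(torch.equal(output[-1, :, 0], h_n[-1, 0]))  # forward: final state is at the last time step
print(torch.equal(output[0, :, 1], h_n[-1, 1]))  # backward: final state is at the first time step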
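As a rough sketch of the same check in a more general setting (the layer count, batch size, and seed below are arbitrary choices for illustration, not anything from the post), this uses two layers and a batch larger than one. output only contains the top layer's hidden states, so it has to be compared against the last layer of h_n, i.e. index -1 along num_layers after the view. The unidirectional case explained in the linked post is included for comparison. torch.allclose is used instead of == so each check collapses to a single bool.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility

# Arbitrary sizes chosen for illustration
seq_len, batch, input_size, hidden_size, num_layers = 6, 3, 5, 8, 2
x = torch.randn(seq_len, batch, input_size)  # seq_len, batch, input_size

# Unidirectional case: the last time step of output is the top layer's final hidden state.
uni = nn.LSTM(input_size, hidden_size, num_layers)
out_uni, (h_uni, _) = uni(x)  # h0/c0 default to zeros when omitted
uni_matches = torch.allclose(out_uni[-1], h_uni[-1])

# Bidirectional case with several layers and batch > 1.
bi = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
out_bi, (h_bi, _) = bi(x)
out_bi = out_bi.view(seq_len, batch, 2, hidden_size)  # seq_len, batch, num_directions, hidden_size
h_bi = h_bi.view(num_layers, 2, batch, hidden_size)   # num_layers, num_directions, batch, hidden_size

# output only holds the top layer, so compare against h_bi[-1].
forward_matches = torch.allclose(out_bi[-1, :, 0], h_bi[-1, 0])   # forward direction ends at the last step
backward_matches = torch.allclose(out_bi[0, :, 1], h_bi[-1, 1])   # backward direction ends at the first step
print(uni_matches, forward_matches, backward_matches)
```

A common pitfall is reading the backward state from output[-1]: at the last time step the backward direction has only seen a single input, so its final state sits at output[0] instead.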