BiLSTM hidden states and output don't match

The behavior of the biLSTM is a bit odd and doesn’t quite follow the documentation. Here is an example, first in the unidirectional case, and then in the bidirectional case.

In the unidirectional case, everything works fine. For example, here we create a 500-dimensional (input and hidden state) LSTM with 2 layers:

import torch
rnn = torch.nn.LSTM(500, 500, 2)
emb = torch.autograd.Variable(torch.randn((50, 1, 500)))
out, hidden = rnn(emb, None)
out
Variable containing:
( 0 ,.,.) =
9.4260e-03 -3.5234e-02 -2.4594e-02 … 1.3439e-02 -3.1169e-03 1.3610e-02

( 1 ,.,.) =
9.3988e-03 -2.7440e-02 -3.3170e-02 … 8.4347e-03 1.6242e-02 2.9061e-02

( 2 ,.,.) =
-2.8333e-02 -1.1560e-02 9.7369e-03 … 3.1523e-02 5.0822e-02 3.4342e-02

(47 ,.,.) =
6.1868e-02 -9.9036e-04 -4.7329e-02 … -1.9764e-02 7.2571e-02 -7.0552e-03

(48 ,.,.) =
4.2237e-02 2.8661e-02 -6.1079e-02 … -1.8695e-02 6.7543e-02 7.6148e-03

(49 ,.,.) =
3.9645e-02 -1.0000e-02 -5.0821e-02 … -3.4104e-02 7.6159e-02 1.8242e-03
[torch.FloatTensor of size 50x1x500]
hidden
Variable containing:
( 0 ,.,.) =
-0.0096 0.0774 -0.1312 … -0.0454 0.2775 -0.0578

( 1 ,.,.) =
0.0396 -0.0100 -0.0508 … -0.0341 0.0762 0.0018
[torch.FloatTensor of size 2x1x500]
, Variable containing:
( 0 ,.,.) =
-0.0164 0.2331 -0.2765 … -0.0953 0.4283 -0.1031

( 1 ,.,.) =
0.0801 -0.0208 -0.1008 … -0.0669 0.1577 0.0036
[torch.FloatTensor of size 2x1x500]
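
As a quick sanity check (just a sketch reusing the rnn, emb, out and hidden from the run above), the difference between the last output and the final hidden state of the top layer should come out as zero:

# out[49, :, :] is the output at the last time step;
# hidden[0][1, :, :] is the top-layer entry of h_n
torch.sum(out[49, :, :] - hidden[0][1, :, :])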

This makes sense: the vector out[49, :, :] is exactly hidden[0][1, :, :], i.e., the hidden state at the top-most layer at the last time step. However, when I try a bidirectional LSTM, things don’t seem to match:

import torch
rnn = torch.nn.LSTM(500, 500, 2, bidirectional=True)
emb = torch.autograd.Variable(torch.randn((50, 1, 500)))
out, hidden = rnn(emb, None)
out
Variable containing:
( 0 ,.,.) =
4.4536e-02 5.3987e-03 1.4463e-02 … -3.0552e-02 5.9231e-03 -2.9692e-02

( 1 ,.,.) =
5.5412e-02 3.5277e-02 -8.9286e-03 … -4.9081e-02 -2.5681e-02 1.4060e-02

( 2 ,.,.) =
5.5511e-02 6.5030e-02 1.2957e-02 … -3.4862e-02 -6.1621e-02 1.1398e-02

( 47 ,.,.) =
1.0626e-02 5.2022e-02 4.3567e-02 … -1.1319e-02 -1.0329e-02 -2.4761e-02

( 48 ,.,.) =
1.0129e-02 6.4200e-02 5.5245e-02 … 2.9507e-02 -8.1535e-03 -8.9475e-03

( 49 ,.,.) =
3.2189e-02 5.6126e-02 7.0614e-02 … 9.3073e-03 -5.9337e-03 6.3327e-06
[torch.FloatTensor of size 50x1x1000]

It makes sense that the output is of size 1000, due to the concatenation of the two directions. What I would expect is that out[49, :, :] corresponds to the concatenation of hidden[0][2, :, :] and hidden[0][3, :, :] (as per the documentation). The first half of out[49, :, :] does indeed correspond to hidden[0][2, :, :]:

torch.sum(out[49, :, :500] - hidden[0][2, :, :])
Variable containing:
0
[torch.FloatTensor of size 1]

the second half, however, does not:

torch.sum(out[49, :, 500:] - hidden[0][3, :, :])
Variable containing:
1.8791
[torch.FloatTensor of size 1]

Am I doing something wrong? If out[49, :, 500:] does not correspond to the hidden state from the reverse direction, what does it correspond to?

Avneesh

You’re actually comparing against the first output of the reverse direction: since that direction processes the sequence starting from the last time step, its final hidden state is produced at time step 0. I think you want out[0, :, 500:].
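
For example, repeating your check but using time step 0 for the reverse half (a sketch assuming out and hidden are still the ones from your bidirectional run above), both differences should come out as zero:

# forward direction: its final state is at the last time step of out
torch.sum(out[49, :, :500] - hidden[0][2, :, :])
# reverse direction: it processes the sequence backwards, so its final state
# is produced at time step 0 of out
torch.sum(out[0, :, 500:] - hidden[0][3, :, :])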

Of course, and here I was thinking I had found a bug! I should have realized.

Thanks!

Why do the last 2 hidden states correspond to the first and the last output?