The issue is that, in the case of a BiLSTM, the notion of "last hidden state" gets a bit murky.
Take for example the sentence "there will be dragons", and let's assume you created your LSTM with batch_first=False. Somewhere in your forward() method you have

output, hidden = lstm(inputs, hidden)

In this case, output[-1] gives you the hidden states for the last word (i.e., "dragons"). However, this is only the last state for the forward direction, since "dragons" is the last word in that direction. For the backward direction, "dragons" is the first word, so you get the first hidden state w.r.t. the backward direction. The last word for the backward direction is "there", which is the first word of your sentence, so the last hidden state for the backward direction sits somewhere in output[0]. I had the same misunderstanding at first; see a previous post of mine.
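To make this concrete, here is a minimal sketch (a toy single-layer BiLSTM with made-up sizes and random inputs, just for illustration) that checks where the last states of the two directions end up in output:

import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 5
lstm = nn.LSTM(input_size=3, hidden_size=hidden_size, bidirectional=True)  # batch_first=False by default

# "there will be dragons": 4 time steps, batch of 1, 3 (made-up) features per word
inputs = torch.randn(4, 1, 3)

output, (h_n, c_n) = lstm(inputs)
# output: (seq_len=4, batch=1, 2 * hidden_size) -- forward and backward states concatenated per time step
# h_n:    (num_directions=2, batch=1, hidden_size) -- last state of each direction

# Forward direction: its last state is at the LAST time step of output
print(torch.allclose(output[-1, :, :hidden_size], h_n[0]))  # True

# Backward direction: its last state is at the FIRST time step of output
print(torch.allclose(output[0, :, hidden_size:], h_n[1]))   # True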
hidden does not have any sequence dimension. For an LSTM it is the tuple (h_n, c_n), and h_n contains the last hidden state of the forward direction (part of output[-1]) and the last hidden state of the backward direction (part of output[0]). Depending on your exact task, using hidden for the next layers is usually the right way to go. If you use output[-1], you basically lose all the information from the backward pass, since there the backward hidden state has seen only one word and not the whole sequence.
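If you want a single vector per sentence that reflects both directions, one common pattern (just a sketch, assuming a single-layer BiLSTM; sentence_representation is a hypothetical helper name) is to concatenate the two last states from h_n:

import torch

def sentence_representation(h_n):
    # h_n: (num_directions=2, batch, hidden_size) for a single-layer BiLSTM
    # h_n[0] is the last forward state, h_n[1] the last backward state
    return torch.cat([h_n[0], h_n[1]], dim=-1)  # (batch, 2 * hidden_size)

That concatenated vector can then be fed to your next layer (e.g., a linear classifier) instead of output[-1].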