BiLSTM: Output & Hidden State Mismatch

I’m trying to understand the mechanics of the LSTM in Pytorch and came across something that I believe has been asked & answered before but I have a follow-up.

A simple example is pasted below. I expected the final output to be a concatenation of the h_n contents.

But it seems like only the first half matches. From related posts, I see that I’m probably looking at the “first hidden state of the reverse sequence”.

My question: “Is there a way for me to access the last hidden state of the reverse sequence to verify that my output[5] is the concatenation of the hidden states?”

Thanks! :slight_smile:

Have a look at @vdw’s great visualization and explanation regarding the hidden state and output of an RNN: post.

While he explains the usage for bidirectional=False, we can use the docs to separate the directions and compare the output to the hidden state:

import torch
import torch.nn as nn

rnn = nn.LSTM(5, 8, 1, bidirectional=True)
h0 = torch.zeros(2*1, 1, 8)
c0 = torch.zeros(2*1, 1, 8)

x = torch.randn(6, 1, 5)
output, (h_n, c_n) = rnn(x, (h0, c0))

# Separate directions
output = output.view(6, 1, 2, 8)  # seq_len, batch, num_directions, hidden_size
h_n = h_n.view(1, 2, 1, 8)  # num_layers, num_directions, batch, hidden_size

# Compare directions
output[-1, :, 0] == h_n[:, 0]  # forward
output[0, :, 1] == h_n[:, 1]  # backward
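
If you prefer a single True/False per direction instead of elementwise comparisons, the same checks can be written with torch.allclose (reusing the tensors from the snippet above):

torch.allclose(output[-1, :, 0], h_n[0, 0])  # True: last forward output step equals the final forward hidden state
torch.allclose(output[0, :, 1], h_n[0, 1])   # True: first output step (backward half) equals the final backward hidden state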

I wondered the same thing in a previous post of mine. As @ptrblck clarified in his answer, the last states of the forward and backward pass are indeed found on the “opposite” ends of output.


This was very helpful @ptrblck & @vdw ! :slight_smile:

So the output is basically the step-wise collection of the hidden states. I get it now.

So when papers say that they use the bidirectional embedding of a sentence in PyTorch, they mean they use the final "h_n", i.e. "forward_final_hidden_layer, backward_final_hidden_layer", right?

Because using output[-1] gives "forward_final_hidden_layer, backward_initial_hidden_layer", which is incorrect to use as a sentence representation, correct?

Yes, I would strongly assume so, since you only get the true/full sentence embedding after the last step.

Yes, at least the way PyTorch implements it.
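
To make that concrete, here is a minimal self-contained sketch for a single-layer BiLSTM (the sizes and variable names are arbitrary, just for illustration):

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=5, hidden_size=8, num_layers=1, bidirectional=True)
x = torch.randn(6, 1, 5)  # (seq_len, batch, input_size)
output, (h_n, c_n) = rnn(x)

# h_n: (num_layers * num_directions, batch, hidden_size)
# For a single layer, h_n[0] is the final forward state and h_n[1] the final backward state.
sentence_emb = torch.cat((h_n[0], h_n[1]), dim=1)  # (batch, 2 * hidden_size)

# output[-1] would instead give forward_final + backward_initial:
last_step = output[-1]  # (batch, 2 * hidden_size); its second half is NOT the final backward state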

An alternative I have occasionally stumbled upon is to calculate the average or maximum over the whole output; please see the forward method of a simple RNN-based classifier I used for testing the effects. It obviously will work (since the dimensions check out) and even train well, but I'm pretty sure I never saw any improvements, which is why I stopped using it. Feel free to try, though. It might work for your setting.

def forward(self, batch, method='last_step'):
    embeds = self.word_embeddings(batch)
    output = torch.transpose(embeds, 0, 1)  # (batch, seq, emb) -> (seq, batch, emb)
    output, self.hidden = self.gru(output, self.hidden)

    if method == 'last_step':
        output = output[-1]  # don't do that for bidirectional=True :)
    elif method == 'average_pooling':
        output = torch.sum(output, dim=0) / len(batch[0])  # mean over the time dimension
    elif method == 'max_pooling':
        output, _ = torch.max(output, dim=0)  # max over the time dimension
    else:
        raise Exception('Unknown method.')

    for l in self.linears:
        output = l(output)
    log_probs = F.log_softmax(output, dim=1)
    return log_probs

Hi @Daniel_Dsouza (and maybe @vdw)

I have been wondering about exactly the same problem. So, I understand that we should use h_n for "forward_final_hidden_layer, backward_final_hidden_layer".

However, I still do not get how you were able to tell that output[-1] is "forward_final_hidden_layer, backward_initial_hidden_layer". I tried to find this information in the documentation, but I failed.

I would really appreciate it if someone could kindly explain it to me, and also what "backward_initial_hidden_layer" means. Does it mean that we get this state during the process of computing "forward_final_hidden_layer"?

Thanks in advance!

The issue is that in case of a BiLSTM, the notion of “last hidden state” gets a bit murky.

Take for example the sentence “there will be dragons”. And let’s assume you created your LSTM with batch_first=False. Somewhere in your forward() method you have

output, hidden = lstm(inputs, hidden)

In this case output[-1] gives you the hidden states for the last word (i.e., "dragons"). However, this is only the last state for the forward direction, since "dragons" is the last word here. For the backward direction, "dragons" is the first word, so you get the first hidden state w.r.t. the backward direction. For the backward direction the last word is "there", which is the first word of your sentence. So the last hidden state for the backward direction is somewhere in output[0]. I had the same misunderstanding at first; see a previous post of mine.

hidden does not have any sequence dimension and will contain the last hidden state of the forward direction (part of output[-1]) and the last hidden state of the backward direction (part of output[0]). Depending on your exact task, using hidden for the next layers is usually the right way to go. If you use output[-1] you basically lose all the information from the backward pass, since you use the hidden state after only one word and not the whole sequence.
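
For the general case with num_layers > 1, you can make this explicit by reshaping hidden as described in the docs and then taking the top layer of each direction. A minimal sketch (the sizes here are arbitrary, just for illustration):

import torch
import torch.nn as nn

num_layers, hidden_size, batch_size = 2, 8, 3
lstm = nn.LSTM(input_size=5, hidden_size=hidden_size, num_layers=num_layers, bidirectional=True)
inputs = torch.randn(4, batch_size, 5)  # (seq_len, batch, input_size), batch_first=False

output, (hidden, cell) = lstm(inputs)

# hidden: (num_layers * num_directions, batch, hidden_size)
hidden = hidden.view(num_layers, 2, batch_size, hidden_size)  # (num_layers, num_directions, batch, hidden_size)
last_forward = hidden[-1, 0]   # top layer, forward direction:  equals output[-1, :, :hidden_size]
last_backward = hidden[-1, 1]  # top layer, backward direction: equals output[0, :, hidden_size:]

sentence_repr = torch.cat((last_forward, last_backward), dim=1)  # (batch, 2 * hidden_size)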
