Get the last hidden state in stacked LSTM

Hello everyone,

I’m developing an LSTM-based classifier, and I defined the model as follows:

class LSTMClassifier(nn.Module):

    def __init__(self, input_size, seq_len, n_classes, hidden_size, device, dim_feedforward=1024, num_layers=1, dropout=0, bidirectional=False, batch_first=True):
        super(LSTMClassifier, self).__init__()

        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.input_size = input_size
        self.num_layers = num_layers
        self.device = device

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, dropout=dropout, bidirectional=bidirectional, batch_first=batch_first)
        self.classifier = nn.Linear(hidden_size, n_classes)
    
    def init_hidden_state(self, x):
        # h0 and c0, each of shape (num_layers, batch, hidden_size)
        return (
            torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device),
            torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
        )
    
    def forward(self, x):
        self.hidden = self.init_hidden_state(x)

        out, (hn, cn) = self.lstm(x, self.hidden)
        hn = hn.view(-1, self.hidden_size)   # flattens to (num_layers * batch, hidden_size)
        y = self.classifier(hn)
        
        return y

It should be correct. If I create a model with hidden_size=64 and num_layers=1, the size of hn (before the view) is [1, 1, 64] when I feed in a single sample.
If I change to num_layers=2, the size of hn (before the view) is [2, 1, 64], because it stacks the hidden states of the two layers.
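For reference, a minimal standalone check of those shapes (input_size=32 and seq_len=10 below are just placeholder values):

import torch
import torch.nn as nn

# Minimal shape check; input_size=32 and seq_len=10 are arbitrary placeholders
x = torch.randn(1, 10, 32)   # (batch, seq_len, input_size) with batch_first=True

lstm_1 = nn.LSTM(input_size=32, hidden_size=64, num_layers=1, batch_first=True)
out, (hn, cn) = lstm_1(x)
print(hn.shape)              # torch.Size([1, 1, 64])

lstm_2 = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
out, (hn, cn) = lstm_2(x)
print(hn.shape)              # torch.Size([2, 1, 64]), one hidden state per layer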

My question is: what is the order of the hidden states in hn? Is the first element in hn the hidden state of the first or of the last LSTM layer?
Is it correct to consider the last hidden state of the last LSTM layer to classify the sequences?

Thanks for your help.

Hi there!

I am not sure if this response is still useful, but I am sharing what I understand about getting the last hidden state from a stacked LSTM for future reference. If I get the question right, you are interested in using the last hidden state of each element in a batch for a classification task. The variable out in the given code already has that information.

In the case of a uni-directional LSTM, it is straightforward: out[:, -1] when batch_first=True should work, since the -1 index refers to the last item in the chain of hidden states (the last time step). hn[-1], taken before reshaping it with view(), gives the same tensor.
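Here is a minimal sketch of that, assuming a uni-directional LSTM with batch_first=True (all sizes are arbitrary placeholders):

import torch
import torch.nn as nn

# Uni-directional, stacked LSTM; the sizes are placeholders
lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 32)      # (batch, seq_len, input_size)

out, (hn, cn) = lstm(x)
last_from_out = out[:, -1]      # (batch, hidden_size): output at the last time step
last_from_hn = hn[-1]           # (batch, hidden_size): final hidden state of the top layer
print(torch.allclose(last_from_out, last_from_hn))   # True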

However, in the case of a bidirectional LSTM, follow the note given in the PyTorch documentation:

For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. Example of splitting the output layers when batch_first=False: output.view(seq_len, batch, num_directions, hidden_size).

Also, for reshaping hn, I would recommend hn.view(num_layers, num_directions, batch, hidden_size), then use index -1 for the last layer and 0/1 to select the direction.
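A sketch of that reshaping, assuming num_layers=2, bidirectional=True and placeholder sizes:

import torch
import torch.nn as nn

# Bidirectional, stacked LSTM; input_size, hidden_size, batch and seq_len are placeholders
num_layers, num_directions, hidden_size = 2, 2, 64
lstm = nn.LSTM(input_size=32, hidden_size=hidden_size, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(4, 10, 32)                      # (batch, seq_len, input_size)
batch_size = x.size(0)

out, (hn, cn) = lstm(x)                         # hn: (num_layers * num_directions, batch, hidden_size)
hn = hn.view(num_layers, num_directions, batch_size, hidden_size)

last_layer = hn[-1]                             # (num_directions, batch, hidden_size)
h_fwd, h_bwd = last_layer[0], last_layer[1]     # direction 0 = forward, 1 = backward
features = torch.cat((h_fwd, h_bwd), dim=1)     # (batch, 2 * hidden_size) to feed a classifier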

Finally, as for whether choosing the last layer's hidden state for classification is the right call, I believe this can only be answered through empirical analysis; it depends heavily on the task at hand.

Hope this helps.

You can have a look at this code; it handles multiple layers as well as 1 or 2 directions. The key snippet is:

# Extract last hidden state
if self.params.rnn_type == RnnType.GRU:
    final_state = self.hidden.view(self.params.num_layers, self.num_directions, batch_size, self.params.rnn_hidden_dim)[-1]
elif self.params.rnn_type == RnnType.LSTM:
    final_state = self.hidden[0].view(self.params.num_layers, self.num_directions, batch_size, self.params.rnn_hidden_dim)[-1]
# Handle directions
final_hidden_state = None
if self.num_directions == 1:
    final_hidden_state = final_state.squeeze(0)
elif self.num_directions == 2:
    h_1, h_2 = final_state[0], final_state[1]
    # final_hidden_state = h_1 + h_2               # Add both states (requires changes to the input size of first linear layer + attention layer)
    final_hidden_state = torch.cat((h_1, h_2), 1)  # Concatenate both states