How to concatenate the hidden states of a bi-LSTM with multiple layers

Suppose you have a hidden-state tensor h_n with shape [4, 16, 256], where the LSTM is 2-layer bidirectional (2 * 2 = 4), the batch size is 16 and the hidden size is 256. What is the correct way to get the concatenated forward and backward hidden states of the last layer (shape [16, 512])?
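To reproduce the shapes in the question, here is a minimal sketch. The sequence length (7) and input size (10) are hypothetical values chosen only for illustration; the layer count, directions, batch size and hidden size match the question.

```python
import torch
import torch.nn as nn

# Hypothetical seq_len and input_size; the other sizes come from the question
batch_size, seq_len, input_size, hidden_size = 16, 7, 10, 256

lstm = nn.LSTM(input_size, hidden_size, num_layers=2,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)

output, (h_n, c_n) = lstm(x)

# batch_first=True only changes the input/output layout;
# h_n and c_n stay (num_layers * num_directions, batch, hidden_size)
print(output.shape)  # torch.Size([16, 7, 512])
print(h_n.shape)     # torch.Size([4, 16, 256])
print(c_n.shape)     # torch.Size([4, 16, 256])
```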

I'm doing the following. Note that the model supports both GRU and LSTM, so I can choose between them at setup time:

def forward(self, inputs):
    batch_size = inputs.shape[0]
    # Push through embedding layer
    X = self.embedding(inputs)
    # Push through RNN layer (the output is irrelevant)
    _, self.hidden = self.rnn(X, self.hidden)

    # Get the final hidden state h_n of every layer and direction
    # (LSTM returns the tuple (h_n, c_n); GRU returns h_n only)
    if self.params.rnn_type == RnnType.RNN_TYPE__GRU:
        hidden = self.hidden
    elif self.params.rnn_type == RnnType.RNN_TYPE__LSTM:
        hidden = self.hidden[0]
    # Flatten hidden state with respect to batch size:
    # (num_layers * num_directions, batch, hidden) -> (batch, num_layers * num_directions * hidden)
    hidden = hidden.transpose(1,0).contiguous().view(batch_size, -1)
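A self-contained sketch of the same forward pass, assuming a hypothetical Encoder class where a use_lstm flag stands in for the RnnType check and the initial hidden state is left at its default (zeros) instead of being passed as self.hidden. Note that this approach flattens all layers, so the result has 2 layers × 2 directions × 256 = 1024 features per sample, not the 512 asked about in the question.

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    # Hypothetical minimal module; `use_lstm` replaces the RnnType check in the post
    def __init__(self, vocab_size, embed_dim, hidden_dim, use_lstm=True):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        rnn_cls = nn.LSTM if use_lstm else nn.GRU
        self.rnn = rnn_cls(embed_dim, hidden_dim, num_layers=2,
                           bidirectional=True, batch_first=True)
        self.use_lstm = use_lstm

    def forward(self, inputs):
        batch_size = inputs.shape[0]
        X = self.embedding(inputs)
        # No initial state passed, so PyTorch uses zeros
        _, hidden = self.rnn(X)
        if self.use_lstm:
            hidden = hidden[0]  # keep h_n, drop c_n
        # (layers * dirs, batch, hidden) -> (batch, layers * dirs * hidden)
        return hidden.transpose(0, 1).contiguous().view(batch_size, -1)

tokens = torch.randint(0, 100, (16, 7))
for use_lstm in (True, False):
    enc = Encoder(vocab_size=100, embed_dim=32, hidden_dim=256, use_lstm=use_lstm)
    print(enc(tokens).shape)  # torch.Size([16, 1024]) for both LSTM and GRU
```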

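The important part is transpose(1,0), which moves the batch dimension to the front; view() then handles the rest, and .contiguous() is needed because view() cannot operate on the non-contiguous tensor that transpose() returns. The transpose is needed regardless of batch_first=True: that flag only changes the layout of the input and output tensors, while h_n (and c_n) always have shape (num_layers * num_directions, batch, hidden_size).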
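A quick way to see why the transpose matters: both versions below have the same shape, but only the transposed one puts the four 256-dimensional states of the same sample into one row. The tensor is random and stands in for a real h_n; the sample index b is arbitrary.

```python
import torch

num_layers_dirs, batch_size, hidden_size = 4, 16, 256
h_n = torch.randn(num_layers_dirs, batch_size, hidden_size)

with_t = h_n.transpose(0, 1).contiguous().view(batch_size, -1)
without_t = h_n.view(batch_size, -1)  # same shape, but rows mix different samples

b = 3  # arbitrary sample index
# Correct row b: the states of sample b from every layer/direction, side by side
expected_row = torch.cat([h_n[i, b] for i in range(num_layers_dirs)])
print(torch.equal(with_t[b], expected_row))     # True
print(torch.equal(without_t[b], expected_row))  # False
```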

Note that the linear layer directly following it must then be created as nn.Linear(rnn_hidden_dim * num_directions * num_layers, output_size), since the flattened vector contains the hidden states of every layer and direction.
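A minimal sketch of that linear layer, assuming a hypothetical output_size of 5. The flattened hidden state is replaced by a random tensor of the right shape.

```python
import torch
import torch.nn as nn

rnn_hidden_dim, num_directions, num_layers = 256, 2, 2
output_size = 5  # hypothetical, e.g. number of classes

fc = nn.Linear(rnn_hidden_dim * num_directions * num_layers, output_size)

flat_hidden = torch.randn(16, rnn_hidden_dim * num_directions * num_layers)
logits = fc(flat_hidden)
print(fc.weight.shape)  # torch.Size([5, 1024]); stored as (out_features, in_features)
print(logits.shape)     # torch.Size([16, 5])
```

The constructor takes (in_features, out_features), even though the weight itself is stored transposed.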

Thanks for the reply. In a two-layer stacked LSTM, the hidden state of the first layer is mostly an intermediate result, and I'm thinking of discarding it. Your code seems to keep it. Do you have any recommendations on how to do so? From the official docs it is not clear which parts of the hidden output (self.hidden[0] in your example) we should pick.

An LSTM returns output, (h_n, c_n). In my code, _, self.hidden = self.rnn(X, self.hidden), self.hidden is the tuple (h_n, c_n), and since I only want h_n, I take hidden = self.hidden[0]. A GRU returns only h_n, which is why the GRU branch uses self.hidden directly.
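An alternative to checking the configured RNN type is to inspect what the RNN actually returned. The helper last_hidden below is hypothetical, not part of PyTorch; it relies only on nn.LSTM returning a (h_n, c_n) tuple and nn.GRU returning a single tensor.

```python
import torch
import torch.nn as nn

def last_hidden(rnn_hidden):
    """Hypothetical helper: return h_n for either an LSTM or a GRU.

    nn.LSTM's second output is the tuple (h_n, c_n); nn.GRU's is h_n alone.
    """
    return rnn_hidden[0] if isinstance(rnn_hidden, tuple) else rnn_hidden

x = torch.randn(16, 7, 10)
for rnn_cls in (nn.LSTM, nn.GRU):
    rnn = rnn_cls(10, 256, num_layers=2, bidirectional=True, batch_first=True)
    _, hidden = rnn(x)
    print(rnn_cls.__name__, last_hidden(hidden).shape)  # torch.Size([4, 16, 256])
```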

If you only want the last layer, the docs say you can separate layers and directions with h_n = h_n.view(num_layers, num_directions, batch, hidden_size). Since num_layers is the first dimension, h_n = h_n[-1] then gives the last layer, with shape (num_directions, batch, hidden_size).
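A sketch that carries this through to the [16, 512] tensor from the original question, by concatenating the two directions of the last layer along the feature dimension. The sequence length and input size are hypothetical. As a sanity check, the forward half should match the top layer's output at the last time step, and the backward half should match it at the first time step, because the backward direction finishes at t = 0.

```python
import torch
import torch.nn as nn

num_layers, num_directions, batch_size, hidden_size = 2, 2, 16, 256
lstm = nn.LSTM(10, hidden_size, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, 7, 10)  # hypothetical seq_len=7, input_size=10
output, (h_n, _) = lstm(x)

# Separate layers and directions, then keep only the last layer
h_n = h_n.view(num_layers, num_directions, batch_size, hidden_size)
last = h_n[-1]  # (num_directions, batch, hidden_size)

# Forward and backward states side by side: (batch, 2 * hidden_size)
last_cat = torch.cat([last[0], last[1]], dim=1)
print(last_cat.shape)  # torch.Size([16, 512])

# Forward state = last time step, backward state = first time step
print(torch.allclose(last_cat[:, :hidden_size], output[:, -1, :hidden_size]))  # True
print(torch.allclose(last_cat[:, hidden_size:], output[:, 0, hidden_size:]))   # True
```

last.transpose(0, 1).reshape(batch_size, -1) gives the same result as the torch.cat call and generalizes to any number of directions.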