Suppose you have a hidden-state tensor of shape [4, 16, 256] from a 2-layer bi-directional LSTM (2 layers * 2 directions = 4), where the batch size is 16 and the hidden dimension is 256. What is the correct way to get the concatenated last-layer hidden state, i.e. a tensor of shape [16, 512]?
I’m doing the following; please note that I support both GRU and LSTM in the model, so I can decide at setup time:
```python
def forward(self, inputs):
    batch_size = inputs.shape[0]
    # Push through embedding layer
    X = self.embedding(inputs)
    # Push through RNN layer (the output is irrelevant)
    _, self.hidden = self.rnn(X, self.hidden)
    # Get the hidden state h_n (a GRU returns it directly, an LSTM returns a tuple)
    if self.params.rnn_type == RnnType.RNN_TYPE__GRU:
        hidden = self.hidden
    elif self.params.rnn_type == RnnType.RNN_TYPE__LSTM:
        hidden = self.hidden[0]
    # Flatten hidden state with respect to batch size
    hidden = hidden.transpose(1, 0).contiguous().view(batch_size, -1)
    ...
```
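For reference, a minimal sketch of the setup this `forward()` assumes (stripped down; the attribute and parameter names beyond `self.embedding`, `self.rnn`, `self.hidden`, and `self.params.rnn_type` are illustrative):

```python
import torch.nn as nn
from enum import Enum

class RnnType(Enum):
    RNN_TYPE__GRU = 0
    RNN_TYPE__LSTM = 1

class RnnClassifier(nn.Module):
    def __init__(self, params, vocab_size):
        super().__init__()
        self.params = params
        self.embedding = nn.Embedding(vocab_size, params.embed_dim)
        rnn_class = nn.GRU if params.rnn_type == RnnType.RNN_TYPE__GRU else nn.LSTM
        # batch_first=True: inputs are (batch, seq_len), so batch_size = inputs.shape[0];
        # note h_n is (num_layers * num_directions, batch, hidden_size) either way
        self.rnn = rnn_class(params.embed_dim, params.rnn_hidden_dim,
                             num_layers=params.num_layers,
                             bidirectional=True, batch_first=True)
        self.hidden = None  # initialized/reset outside of forward()
```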
The important part is the `transpose(1, 0)` to get the batch size to the front; everything else is handled by the `view()` command. I only need the transpose since I initialize the RNN with `batch_first=True`. Note that the directly following linear layer then needs to be defined as `nn.Linear(rnn_hidden_dim * num_directions * num_layers, output_size)` to match the flattened hidden state.
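To make the shapes concrete, here is a self-contained toy example of that flattening, using the numbers from your question (2 layers, 2 directions, batch 16, hidden 256; the `output_size` of 10 is just a placeholder):

```python
import torch
import torch.nn as nn

num_layers, num_directions, batch, hidden_dim = 2, 2, 16, 256
output_size = 10  # placeholder

h_n = torch.randn(num_layers * num_directions, batch, hidden_dim)  # [4, 16, 256]

# Move the batch dimension to the front, then flatten the rest per example
flat = h_n.transpose(1, 0).contiguous().view(batch, -1)            # [16, 1024]

fc = nn.Linear(hidden_dim * num_directions * num_layers, output_size)
out = fc(flat)                                                     # [16, 10]
```

Note that this flattens the hidden states of *all* layers, so you get [16, 1024] here rather than the [16, 512] you asked about; keeping only the last layer is covered below.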
Thanks for the reply. In a two-layer stacked LSTM, the hidden state of the first layer is pretty much an intermediate result, and I am thinking of getting rid of it. It seems your code does not touch this part. Do you have any recommendations on how to do so? From the official docs it is not clear which parts of the hidden output (`self.hidden` in your example) we should pick.
The output of an LSTM is `output, (h_n, c_n)`. In my code, after `_, self.hidden = self.rnn(X, self.hidden)`, `self.hidden` is the tuple `(h_n, c_n)`, and since I only want `h_n`, I have to do `hidden = self.hidden[0]`.
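A quick toy example of that unpacking (the `input_size` of 100 and the sequence length of 30 are arbitrary here):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=100, hidden_size=256, num_layers=2,
               bidirectional=True, batch_first=True)
x = torch.randn(16, 30, 100)   # (batch, seq_len, input_size)

output, (h_n, c_n) = lstm(x)   # the hidden state comes back as a tuple
print(output.shape)            # torch.Size([16, 30, 512])
print(h_n.shape)               # torch.Size([4, 16, 256])
print(c_n.shape)               # torch.Size([4, 16, 256])
# in my code, self.hidden is this (h_n, c_n) tuple, so self.hidden[0] is h_n
```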
In case you only want the last layer, the docs say you can separate the hidden state with `h_n = h_n.view(num_layers, num_directions, batch, hidden_size)`. Since `num_layers` is the first dimension, you only need to do `h_n = h_n[-1]` to get the last layer; the shape will then be `(num_directions, batch, hidden_size)`.
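Putting it all together for your original question, i.e. going from [4, 16, 256] to [16, 512] (a sketch with random data; `torch.cat` along the feature dimension concatenates the forward and backward states):

```python
import torch

num_layers, num_directions, batch, hidden_size = 2, 2, 16, 256
h_n = torch.randn(num_layers * num_directions, batch, hidden_size)  # [4, 16, 256]

# Separate the layer and direction dimensions as described in the docs
h_n = h_n.view(num_layers, num_directions, batch, hidden_size)

last_layer = h_n[-1]                                      # [2, 16, 256]

# Concatenate forward (index 0) and backward (index 1) hidden states
final = torch.cat([last_layer[0], last_layer[1]], dim=1)  # [16, 512]
```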