LSTM hidden_states issue

For an LSTM in PyTorch, I first figured out that the LSTM hidden states (h and c) each have size [num_layers*num_directions, batch_size, hidden_features]. How do I extract the final hidden state of the sequence from them? My current understanding is below; is it right? I couldn't find any documentation that explains this, so any help would be appreciated.

# For a multi-layer unidirectional LSTM network
import torch.nn as nn

rnn = nn.LSTM(input_size, hidden_features, num_layers, batch_first=True, bidirectional=False)
input = initializeInput()
# input.size() = [batch_size, seq_len, input_size]  (here input_size == hidden_features)
hidden_cell = initializeHiddenCell()
# hidden_cell is a tuple (h_0, c_0); each has size [num_layers*num_directions, batch_size, hidden_features]

output, hidden_cell = rnn(input, hidden_cell)
# hidden state of the last layer after the final time step, size [batch_size, hidden_features]:
last_hidden = hidden_cell[0][-1]
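A quick way to check the unidirectional case, sketched here with arbitrary sizes chosen only for illustration: since output holds the top layer's hidden state at every time step, the last layer's final hidden state h_n[-1] should match the last time step of output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, for illustration only
batch_size, seq_len, input_size, hidden_size, num_layers = 3, 5, 4, 6, 2

rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)

# With no initial state passed, h_0 and c_0 default to zeros
output, (h_n, c_n) = rnn(x)

print(output.shape)  # torch.Size([3, 5, 6]) -> [batch, seq_len, hidden_size]
print(h_n.shape)     # torch.Size([2, 3, 6]) -> [num_layers * 1, batch, hidden_size]

# h_n[-1] is the last layer's hidden state after the final time step,
# which is what output holds at position seq_len - 1
print(torch.allclose(h_n[-1], output[:, -1, :]))  # True
```

Note that batch_first=True only changes the layout of the input and output tensors; h_n and c_n keep the batch dimension second.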


# For a multi-layer bidirectional LSTM network
rnn = nn.LSTM(input_size, hidden_features, num_layers, batch_first=True, bidirectional=True)
input = initializeInput()
# input.size() = [batch_size, seq_len, input_size]  (here input_size == hidden_features)
hidden_cell = initializeHiddenCell()
# hidden_cell is a tuple (h_0, c_0); each has size [num_layers*num_directions, batch_size, hidden_features], with num_directions = 2

output, hidden_cell = rnn(input, hidden_cell)
# last layer, forward and backward directions:
forward_hidden = hidden_cell[0][-2]
backward_hidden = hidden_cell[0][-1]
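For the bidirectional case, a similar sanity check can be sketched (again with arbitrary illustrative sizes). The first hidden_size features of output belong to the forward direction and the rest to the backward direction; the backward direction reads the sequence in reverse, so its final state lines up with time step 0 rather than the last step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, for illustration only
batch_size, seq_len, input_size, hidden_size, num_layers = 3, 5, 4, 6, 2

rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
x = torch.randn(batch_size, seq_len, input_size)
output, (h_n, c_n) = rnn(x)

print(output.shape)  # torch.Size([3, 5, 12]) -> [batch, seq_len, 2 * hidden_size]
print(h_n.shape)     # torch.Size([4, 3, 6])  -> [num_layers * 2, batch, hidden_size]

forward_last = h_n[-2]   # last layer, forward direction
backward_last = h_n[-1]  # last layer, backward direction

# The forward direction finishes at the last time step
print(torch.allclose(forward_last, output[:, -1, :hidden_size]))  # True
# The backward direction finishes at time step 0
print(torch.allclose(backward_last, output[:, 0, hidden_size:]))  # True
```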

If you look at the docs, you can separate the num_layers*num_directions dimension with

h = h.view(num_layers, num_directions, batch, hidden_size)

After that, the last layer is simply h[-1], the last layer for the forward pass is h[-1][0], and the last layer for the backward pass is h[-1][1].
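The view approach can be sketched end to end as follows (sizes are arbitrary and only for illustration). It confirms that h[-1][0] and h[-1][1] are the same tensors as h_n[-2] and h_n[-1]. The final concatenation is just one common way to get a single per-sequence vector from both directions, not something the thread prescribes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, for illustration only
batch_size, seq_len, input_size, hidden_size, num_layers = 3, 5, 4, 6, 2
num_directions = 2

rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
x = torch.randn(batch_size, seq_len, input_size)
_, (h_n, _) = rnn(x)

# Split the first dimension into separate layer and direction axes
h = h_n.view(num_layers, num_directions, batch_size, hidden_size)

forward_last = h[-1][0]   # last layer, forward direction
backward_last = h[-1][1]  # last layer, backward direction

print(torch.equal(forward_last, h_n[-2]), torch.equal(backward_last, h_n[-1]))  # True True

# One possible way to combine both directions into one vector per sequence
combined = torch.cat([forward_last, backward_last], dim=1)
print(combined.shape)  # torch.Size([3, 12])
```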

Does this answer your question? I'm not quite sure I understood what you need to know.


Exactly! That is what I am trying to figure out. Thank you so much!