How to obtain cell states from a bidirectional LSTM in PyTorch?

Hello, I would like to know how to extract the cell states at each time step from a single-layer bidirectional LSTM. From the tutorials and discussions, this snippet works for a unidirectional LSTM, but I don't know how to adapt it for the bidirectional one.

import torch
import torch.nn as nn

lstm = nn.LSTM(3, 3)  # Input dim is 3, hidden (output) dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # make a sequence of length 5
# initialize the hidden and cell states, each of shape (num_layers, batch, hidden_size).
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden is the tuple (h_t, c_t), so hidden[1] is the cell state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)

This post, which visualizes the internal tensors, might be helpful, as might this follow-up for the bidirectional case.

There are 2 main concepts with LSTMs:

  1. output: PyTorch returns the last layer's hidden state h_t for every time step, with the forward and reverse directions concatenated along the last dimension. In a multi-layer LSTM, only the last layer's output is returned.
  2. hidden state: PyTorch returns only the final hidden state h_n (and the final cell state c_n) of each direction, i.e. after the forward direction has processed the last element of the sequence and the reverse direction has processed the first. In a multi-layer LSTM, this final state is returned for every layer.
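As a rough illustration of these two return values, here is a minimal sketch with a 2-layer bidirectional LSTM. The sizes and the names `out_dirs` and `h_dirs` are arbitrary choices for the example, and it assumes the default `batch_first=False` layout.

```python
import torch
import torch.nn as nn

# Arbitrary sizes chosen for illustration only.
seq_len, batch, input_size, hidden_size, num_layers = 5, 1, 3, 4, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

# With no initial state passed, h_0 and c_0 default to zeros.
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([5, 1, 8])  -> (seq_len, batch, 2 * hidden_size)
print(h_n.shape)     # torch.Size([4, 1, 4])  -> (num_layers * 2, batch, hidden_size)
print(c_n.shape)     # torch.Size([4, 1, 4])  -> same layout as h_n

# Separate the two directions: index 0 is forward, index 1 is reverse.
out_dirs = output.view(seq_len, batch, 2, hidden_size)
h_dirs = h_n.view(num_layers, 2, batch, hidden_size)

# The last layer's forward final state is the forward half of the last time step;
# its reverse final state is the reverse half of the first time step.
print(torch.allclose(out_dirs[-1, :, 0], h_dirs[-1, 0]))
print(torch.allclose(out_dirs[0, :, 1], h_dirs[-1, 1]))
```

Note that `output` holds hidden states only; the cell states appear solely in `c_n`, and only for the final step of each direction.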

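To obtain the final hidden state of the last layer in both directions, write the return value as output, (h_n, c_n) = lstm(x) and use h_n[-2] for the forward direction and h_n[-1] for the reverse direction; c_n[-2] and c_n[-1] are the corresponding final cell states. Here the index is applied to the h_n tensor, not to the (h_n, c_n) tuple that the snippet in the question calls hidden. nn.LSTM itself does not return the cell states for intermediate time steps.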
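For the original question (cell states at every time step), one possible workaround is to step through the sequence manually. The sketch below is not a built-in PyTorch feature: it assumes a single-layer bidirectional `nn.LSTM` with zero initial states, copies each direction's weights into an `nn.LSTMCell`, and runs the forward and reverse passes by hand. The helpers `make_cell` and `run_direction` are hypothetical names introduced for this example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 5, 1, 3, 3
bilstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)


def make_cell(lstm, suffix):
    """Hypothetical helper: an nn.LSTMCell with layer-0 weights of one direction.

    suffix is "" for the forward direction and "_reverse" for the reverse one.
    """
    cell = nn.LSTMCell(lstm.input_size, lstm.hidden_size)
    with torch.no_grad():
        cell.weight_ih.copy_(getattr(lstm, "weight_ih_l0" + suffix))
        cell.weight_hh.copy_(getattr(lstm, "weight_hh_l0" + suffix))
        cell.bias_ih.copy_(getattr(lstm, "bias_ih_l0" + suffix))
        cell.bias_hh.copy_(getattr(lstm, "bias_hh_l0" + suffix))
    return cell


def run_direction(cell, x, reverse=False):
    """Hypothetical helper: step the cell over x and keep h_t and c_t for every t."""
    n = x.size(0)
    steps = range(n - 1, -1, -1) if reverse else range(n)
    h = x.new_zeros(x.size(1), cell.hidden_size)  # zero initial hidden state
    c = x.new_zeros(x.size(1), cell.hidden_size)  # zero initial cell state
    hs, cs = [None] * n, [None] * n
    for t in steps:
        h, c = cell(x[t], (h, c))
        hs[t], cs[t] = h, c  # store at the original time index
    return torch.stack(hs), torch.stack(cs)


with torch.no_grad():
    h_fwd, c_fwd = run_direction(make_cell(bilstm, ""), x)
    h_bwd, c_bwd = run_direction(make_cell(bilstm, "_reverse"), x, reverse=True)
    # Per-step cell states, laid out like `output`: (seq_len, batch, 2 * hidden_size)
    cell_states = torch.cat([c_fwd, c_bwd], dim=-1)
    out, (h_n, c_n) = bilstm(x)  # reference run for a sanity check

# The manually computed hidden states should match the built-in output.
print(torch.allclose(torch.cat([h_fwd, h_bwd], dim=-1), out, atol=1e-5))
```

Comparing the manual hidden states against `out`, and `c_fwd[-1]` / `c_bwd[0]` against `c_n[-2]` / `c_n[-1]`, is a cheap way to confirm the weights were copied correctly. Stepping in Python is slower than the fused `nn.LSTM` kernel, so this is probably best kept for inspection or debugging.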