How to obtain cell states from a bidirectional LSTM in PyTorch?

Hello, I would like to know how to extract the cell states at each time step from a single-layer bidirectional LSTM. Based on the tutorials and discussions I have read, the snippet below works for a unidirectional LSTM, but I don't know how to adapt it to the bidirectional case.

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(3, 3)  # input dim is 3, hidden (output) dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # a sequence of length 5
# Initialize the hidden and cell states, each (num_layers, batch, hidden_size).
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
cell_states = []
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden is the tuple (h_t, c_t); hidden[1] is the cell state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    cell_states.append(hidden[1])
```
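
The step-by-step loop does not carry over directly: with `bidirectional=True`, each call runs the reverse direction only over the tensor passed in, so feeding one element at a time would give the backward direction a sequence of length 1. One workaround, sketched here under stated assumptions (a single layer, `batch_first=False`, biases enabled, no `proj_size`), is to copy the forward weights (`weight_ih_l0`, ...) and the `_reverse` weights of the bidirectional `nn.LSTM` into two `nn.LSTMCell` modules and run them over the sequence in opposite orders. The helper name `bidirectional_cell_states` is hypothetical; the final checks compare its results with the module's own `out` and `c_n`.

```python
import torch
import torch.nn as nn


def bidirectional_cell_states(lstm: nn.LSTM, seq: torch.Tensor):
    """Per-step hidden and cell states for both directions of a
    single-layer bidirectional nn.LSTM (hypothetical helper).

    seq has shape (seq_len, batch, input_size). Each returned tensor has
    shape (seq_len, batch, hidden_size) and is indexed by original time step.
    """
    assert lstm.bidirectional and lstm.num_layers == 1
    assert not lstm.batch_first and lstm.bias and lstm.proj_size == 0
    cells = {}
    for suffix in ("", "_reverse"):
        cell = nn.LSTMCell(lstm.input_size, lstm.hidden_size)
        # Copy weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0 (or the *_reverse ones).
        cell.load_state_dict({
            name: getattr(lstm, f"{name}_l0{suffix}")
            for name in ("weight_ih", "weight_hh", "bias_ih", "bias_hh")
        })
        cells[suffix] = cell

    def run(cell, steps):
        # Zero initial state, matching nn.LSTM's default when no (h0, c0) is given.
        h = c = steps.new_zeros(steps.size(1), lstm.hidden_size)
        hs, cs = [], []
        for x_t in steps:  # x_t: (batch, input_size)
            h, c = cell(x_t, (h, c))
            hs.append(h)
            cs.append(c)
        return torch.stack(hs), torch.stack(cs)

    h_fwd, c_fwd = run(cells[""], seq)                  # t = 0 .. T-1
    h_bwd, c_bwd = run(cells["_reverse"], seq.flip(0))  # t = T-1 .. 0
    # Flip the backward results back so that index t means time step t.
    return h_fwd, c_fwd, h_bwd.flip(0), c_bwd.flip(0)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4, bidirectional=True)
seq = torch.randn(5, 2, 3)  # (seq_len=5, batch=2, input_size=3)
out, (h_n, c_n) = lstm(seq)
h_fwd, c_fwd, h_bwd, c_bwd = bidirectional_cell_states(lstm, seq)

# out[t] concatenates the forward and backward hidden states at step t.
print(torch.allclose(out[..., :4], h_fwd, atol=1e-5))
print(torch.allclose(out[..., 4:], h_bwd, atol=1e-5))
# c_n only holds the last cell state of each direction.
print(torch.allclose(c_n[0], c_fwd[-1], atol=1e-5))
print(torch.allclose(c_n[1], c_bwd[0], atol=1e-5))
```

Note that the backward direction's "last" cell state sits at time step 0, which is why `c_n[1]` matches `c_bwd[0]` rather than `c_bwd[-1]`.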

This post, which visualizes the internal tensors, might be helpful, as might this follow-up covering the bidirectional case.
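
To look at the internal tensors directly (the gates as well as the cell state), one LSTM step can also be written out by hand. The PyTorch `nn.LSTM` documentation stacks the gate weights in the order input, forget, cell (g), output, so `weight_ih_l0` has shape `(4*hidden_size, input_size)` and splits cleanly with `chunk(4)`. This is a minimal sketch assuming a single-layer bidirectional LSTM with biases and `batch_first=False`; `lstm_step` and `run_direction` are hypothetical names, and the last lines only check the result against the module's `c_n`.

```python
import torch
import torch.nn as nn


def lstm_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    """One LSTM step written out by hand (hypothetical helper).

    Gate order follows the nn.LSTM docs: input, forget, cell (g), output.
    """
    gates = x @ w_ih.T + b_ih + h @ w_hh.T + b_hh  # (batch, 4*hidden_size)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, g, o = i.sigmoid(), f.sigmoid(), g.tanh(), o.sigmoid()
    c_next = f * c + i * g           # new cell state
    h_next = o * torch.tanh(c_next)  # new hidden state
    return h_next, c_next, torch.stack([i, f, g, o])


def run_direction(lstm, seq, suffix):
    """Collect cell states and gates for one direction ("" or "_reverse")."""
    w = [getattr(lstm, f"{n}_l0{suffix}")
         for n in ("weight_ih", "weight_hh", "bias_ih", "bias_hh")]
    steps = seq.flip(0) if suffix == "_reverse" else seq
    h = c = seq.new_zeros(seq.size(1), lstm.hidden_size)  # zero initial state
    cells, gates = [], []
    for x_t in steps:
        h, c, g = lstm_step(x_t, h, c, *w)
        cells.append(c)
        gates.append(g)
    cells, gates = torch.stack(cells), torch.stack(gates)
    if suffix == "_reverse":  # re-index by original time step
        cells, gates = cells.flip(0), gates.flip(0)
    return cells, gates


torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=4, bidirectional=True)
seq = torch.randn(5, 2, 3)
with torch.no_grad():
    out, (h_n, c_n) = lstm(seq)
    c_fwd, gates_fwd = run_direction(lstm, seq, "")
    c_bwd, gates_bwd = run_direction(lstm, seq, "_reverse")

print(gates_fwd.shape)  # (seq_len, 4 gates, batch, hidden_size)
print(torch.allclose(c_n[0], c_fwd[-1], atol=1e-5))
print(torch.allclose(c_n[1], c_bwd[0], atol=1e-5))
```

The cell-state update `c_next = f * c + i * g` is the line to watch: it is what `nn.LSTM` computes internally at every step but only exposes for the final step in `c_n`.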