Hello, I would like to know how we can extract the cell states at each time step from a single-layer bidirectional LSTM. Based on tutorials and forum discussions, the snippet below works for a unidirectional LSTM, but I don’t know how to adapt it for a bidirectional one.
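```python
import torch
import torch.nn as nn

lstm = nn.LSTM(3, 3)  # input dim is 3, hidden (output) dim is 3
inputs = [torch.randn(1, 3) for _ in range(5)]  # a sequence of length 5

# Initialize the (hidden state, cell state) pair; each has shape
# (num_layers * num_directions, batch, hidden_size).
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
cell_states = []
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, `hidden` is the tuple (h_t, c_t).
    out, hidden = lstm(i.view(1, 1, -1), hidden)
    cell_states.append(hidden[1])  # c_t for this time step
```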
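One possible approach, sketched under assumptions rather than offered as a definitive answer: `nn.LSTM` only returns the final cell state `c_n`, and feeding the bidirectional module one element at a time would hand the backward direction its state from step t-1, while it actually needs the state from step t+1, since it runs from the last element to the first. The sketch below copies layer 0's weights for each direction (`weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, `bias_hh_l0` and their `_reverse` counterparts) into two `nn.LSTMCell` modules. It runs the forward cell over t = 0..T-1 and the reverse cell over t = T-1..0, storing c_t at every step. The helpers `make_cell` and `run_direction` are hypothetical names introduced here. The final asserts compare the per-step h_t against `output` and the last cell states against `c_n` from an ordinary bidirectional call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, seq_len, batch = 3, 3, 5, 1

lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)  # (seq_len, batch, input_size)


def make_cell(lstm, suffix):
    """Hypothetical helper: an LSTMCell holding a copy of one direction of layer 0."""
    cell = nn.LSTMCell(lstm.input_size, lstm.hidden_size)
    cell.load_state_dict({
        "weight_ih": getattr(lstm, "weight_ih_l0" + suffix),
        "weight_hh": getattr(lstm, "weight_hh_l0" + suffix),
        "bias_ih": getattr(lstm, "bias_ih_l0" + suffix),
        "bias_hh": getattr(lstm, "bias_hh_l0" + suffix),
    })
    return cell


def run_direction(cell, x, reverse):
    """Hypothetical helper: step one direction and keep (h_t, c_t) for every t."""
    n = x.size(0)
    h = torch.zeros(x.size(1), cell.hidden_size)  # zero initial state, as in lstm(x)
    c = torch.zeros_like(h)
    hs, cs = [None] * n, [None] * n
    steps = reversed(range(n)) if reverse else range(n)
    for t in steps:
        h, c = cell(x[t], (h, c))
        hs[t], cs[t] = h, c  # store at the original time index t
    return torch.stack(hs), torch.stack(cs)  # each (seq_len, batch, hidden_size)


with torch.no_grad():
    h_fwd, c_fwd = run_direction(make_cell(lstm, ""), x, reverse=False)
    h_bwd, c_bwd = run_direction(make_cell(lstm, "_reverse"), x, reverse=True)
    output, (h_n, c_n) = lstm(x)

# Per-step cell states laid out like `output`: (seq_len, batch, 2 * hidden_size),
# forward direction in the first half of the last dim, backward in the second.
cell_states = torch.cat([c_fwd, c_bwd], dim=-1)

# Sanity checks against the built-in bidirectional LSTM.
assert torch.allclose(output[..., :hidden_size], h_fwd, atol=1e-5)
assert torch.allclose(output[..., hidden_size:], h_bwd, atol=1e-5)
assert torch.allclose(c_n[0], c_fwd[-1], atol=1e-5)  # forward ends at t = T-1
assert torch.allclose(c_n[1], c_bwd[0], atol=1e-5)   # backward ends at t = 0
```

Because `load_state_dict` copies the tensors, the cells in this sketch do not share gradients with `lstm`. That is fine for inspecting cell states after training, but it would need a different setup (for example, reusing the same `Parameter` objects) if you wanted to backpropagate through these per-step states.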