How to retrieve hidden states for all time steps in LSTM or BiLSTM?

For a fixed-length sequence, say X = [x1, x2, …, x1000], if you use
lstm = nn.LSTM(input_size, hidden_size)
lstm_output, (h_n, c_n) = lstm(X)
it may take about 1 second, whereas an nn.LSTMCell driven by a for loop takes about 10 seconds, almost 10x slower than nn.LSTM.
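If you want to check the speed gap on your own machine, here is a rough timing sketch. The sizes are assumptions for illustration, not the original setup, and the ratio you see will depend on hardware, tensor sizes, and whether nn.LSTM can dispatch to cuDNN on a GPU.

```python
import time

import torch
import torch.nn as nn

# Illustrative sizes (assumed, not from the original post).
SEQ_LEN, BATCH, INPUT_SIZE, HIDDEN_SIZE = 1000, 1, 32, 64


def time_lstm_vs_cell(seq_len=SEQ_LEN):
    """Return (seconds for one nn.LSTM call, seconds for an LSTMCell loop)."""
    x = torch.randn(seq_len, BATCH, INPUT_SIZE)  # (seq_len, batch, input_size)
    lstm = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE)
    cell = nn.LSTMCell(INPUT_SIZE, HIDDEN_SIZE)

    with torch.no_grad():
        start = time.perf_counter()
        lstm_output, (h_n, c_n) = lstm(x)  # a single call over the whole sequence
        t_lstm = time.perf_counter() - start

        h = torch.zeros(BATCH, HIDDEN_SIZE)
        c = torch.zeros(BATCH, HIDDEN_SIZE)
        start = time.perf_counter()
        for t in range(seq_len):  # one Python-level call per time step
            h, c = cell(x[t], (h, c))
        t_cell = time.perf_counter() - start

    return t_lstm, t_cell
```

For example, `print(time_lstm_vs_cell())` prints the two timings; the numbers are only indicative and will differ between runs.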
Here is the catch: if you want to save the cell state at every time step of the sequence, you have to use nn.LSTMCell and append each cell state to a list yourself. nn.LSTM only returns the final cell state c_n (its lstm_output already contains the last layer's hidden state h_t for every time step).
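A minimal sketch of the LSTMCell approach, assuming a single-layer, unidirectional nn.LSTM with batch_first=False. The helper name lstm_with_all_cell_states is hypothetical. It copies the layer-0 weights into an nn.LSTMCell (both use the same gate layout) so the loop reproduces nn.LSTM exactly while also keeping every c_t. Because the weights are copied, gradients do not flow back into the original nn.LSTM, so this form suits inspection. For training you would use the LSTMCell as your model directly.

```python
import torch
import torch.nn as nn


def lstm_with_all_cell_states(lstm: nn.LSTM, x: torch.Tensor):
    """Run a single-layer, unidirectional nn.LSTM step by step with an
    LSTMCell sharing its weights; return h_t and c_t for every time step.

    x has shape (seq_len, batch, input_size) since batch_first=False.
    """
    assert lstm.num_layers == 1 and not lstm.bidirectional and not lstm.batch_first
    cell = nn.LSTMCell(lstm.input_size, lstm.hidden_size, bias=lstm.bias)
    with torch.no_grad():
        # Copy layer-0 parameters so the cell computes what nn.LSTM computes.
        cell.weight_ih.copy_(lstm.weight_ih_l0)
        cell.weight_hh.copy_(lstm.weight_hh_l0)
        if lstm.bias:
            cell.bias_ih.copy_(lstm.bias_ih_l0)
            cell.bias_hh.copy_(lstm.bias_hh_l0)

    batch = x.size(1)
    h = x.new_zeros(batch, lstm.hidden_size)
    c = x.new_zeros(batch, lstm.hidden_size)
    hs, cs = [], []
    for t in range(x.size(0)):
        h, c = cell(x[t], (h, c))
        hs.append(h)
        cs.append(c)  # the per-step cell state that nn.LSTM does not return
    # Each result has shape (seq_len, batch, hidden_size).
    return torch.stack(hs), torch.stack(cs)
```

The stacked hs should match lstm_output, and cs[-1] should match c_n[0].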
If you use a GRU, there is no separate cell state, so the hidden states for the whole sequence are simply the output of nn.GRU. Say
gru = nn.GRU(input_size, hidden_size)
gru_outputs, _ = gru(X)
in which
gru_outputs = [h1, h2, ..., h1000]
so gru_outputs is exactly what you want to retrieve.
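A small sketch of the GRU case, including the bidirectional layout asked about in the title. The sizes are assumptions. For a bidirectional RNN the output concatenates the forward and backward hidden states along the last dimension, and nn.LSTM(bidirectional=True) lays out its output the same way.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumed).
seq_len, batch, input_size, hidden_size = 1000, 2, 32, 64
x = torch.randn(seq_len, batch, input_size)

gru = nn.GRU(input_size, hidden_size)
gru_outputs, h_n = gru(x)  # (seq_len, batch, hidden_size): [h1, ..., h1000]
assert torch.allclose(gru_outputs[-1], h_n[0])  # last output == final hidden state

bigru = nn.GRU(input_size, hidden_size, bidirectional=True)
bi_outputs, bi_h_n = bigru(x)  # (seq_len, batch, 2 * hidden_size)
forward_h = bi_outputs[..., :hidden_size]   # forward h_t at every step
backward_h = bi_outputs[..., hidden_size:]  # backward h_t, aligned to input step t
# bi_h_n holds (forward, backward) final states: the forward direction ends at
# the last step, the backward direction ends at the first step.
assert torch.allclose(forward_h[-1], bi_h_n[0])
assert torch.allclose(backward_h[0], bi_h_n[1])
```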
