For a fixed-length sequence, say X = [x1, x2, …, x1000], if you use
lstm_output, lstm_states = lstm(X)  # lstm = nn.LSTM(input_size, hidden_size)
it may take about 1 second,
while if you use nn.LSTMCell in a for loop, it will take about 10 seconds, almost 10x slower than nn.LSTM.
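To check the gap on your own machine, a rough timing sketch might look like the following. The sizes (batch of 8, 32 input features, 64 hidden units) are assumptions picked for illustration, and the actual ratio will depend on hardware and tensor shapes, so treat the 10x figure as a ballpark rather than a guarantee.

```python
import time

import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the original post).
seq_len, batch, input_size, hidden_size = 1000, 8, 32, 64
X = torch.randn(seq_len, batch, input_size)  # (seq_len, batch, input_size)

lstm = nn.LSTM(input_size, hidden_size)
cell = nn.LSTMCell(input_size, hidden_size)

with torch.no_grad():
    # One call over the whole sequence.
    start = time.perf_counter()
    lstm_output, (h_n, c_n) = lstm(X)
    t_lstm = time.perf_counter() - start

    # Explicit Python loop, one time step per call.
    h = torch.zeros(batch, hidden_size)
    c = torch.zeros(batch, hidden_size)
    start = time.perf_counter()
    for x_t in X:  # iterates over the time dimension
        h, c = cell(x_t, (h, c))
    t_cell = time.perf_counter() - start

print(f"nn.LSTM: {t_lstm:.4f}s   nn.LSTMCell loop: {t_cell:.4f}s")
```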
Here comes the problem: if you want to save the cell state at every step of the sequence, you have to use nn.LSTMCell and push each lstm_cell_state into a list, because nn.LSTM returns the hidden state for every step but only the last cell state.
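A minimal sketch of that loop, assuming a single-layer cell, zero initial states, and small made-up sizes, could be:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small illustrative sizes (assumptions for this sketch).
seq_len, batch, input_size, hidden_size = 20, 4, 8, 16
X = torch.randn(seq_len, batch, input_size)  # (seq_len, batch, input_size)

cell = nn.LSTMCell(input_size, hidden_size)
h = torch.zeros(batch, hidden_size)  # initial hidden state
c = torch.zeros(batch, hidden_size)  # initial cell state

hidden_states, cell_states = [], []
for x_t in X:  # one time step per iteration
    h, c = cell(x_t, (h, c))
    hidden_states.append(h)
    cell_states.append(c)

# Stack the per-step states into (seq_len, batch, hidden_size) tensors.
hidden_states = torch.stack(hidden_states)
cell_states = torch.stack(cell_states)
```

After the loop, cell_states[t] holds the cell state at step t, which is the part nn.LSTM does not hand back for intermediate steps.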
If you use a GRU, there is no separate cell state: the hidden states for the whole sequence are just the output of nn.GRU. Say
gru_outputs, _ = gru(X)  # gru = nn.GRU(input_size, hidden_size)
in which
gru_outputs = [h1, h2, ..., h1000]
Then gru_outputs is exactly what you want to retrieve.
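As a quick sanity check, here is a small sketch assuming a single-layer, unidirectional GRU and made-up sizes. It confirms that the last entry of the output is the same tensor as the final hidden state that nn.GRU returns separately:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small illustrative sizes (assumptions for this sketch).
seq_len, batch, input_size, hidden_size = 20, 4, 8, 16
X = torch.randn(seq_len, batch, input_size)  # (seq_len, batch, input_size)

gru = nn.GRU(input_size, hidden_size)  # single layer, unidirectional
gru_outputs, h_n = gru(X)

# gru_outputs: (seq_len, batch, hidden_size), one hidden state per step.
# h_n:         (num_layers, batch, hidden_size), final hidden state only.
last_matches = torch.allclose(gru_outputs[-1], h_n[-1])
print(gru_outputs.shape, h_n.shape, last_matches)
```

With stacked layers, the output only contains the hidden states of the top layer, so this shortcut covers the last layer rather than every layer.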