You only need to have a look at the documentation:
- Yeah, the output of an LSTM is `output, (h_n, c_n)`
- The shape of `output` is `(seq_len, batch, num_directions * hidden_size)`, where `batch` is the batch size, i.e., the number of sequences in your batch
- The shape of `h_n` is `(num_layers * num_directions, batch, hidden_size)`; again, `batch` being the batch size.
So if you want all the hidden states of the 2nd sequence, you could do `output[:, 1, :]`, or first swap the first two dimensions with `output = output.transpose(0, 1)` and then take `output[1]`.
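To make the shapes concrete, here is a minimal sketch. The sizes are arbitrary values picked for illustration, and it assumes the default `batch_first=False`, which is the layout the shapes above refer to.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, chosen only for illustration
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2

# Unidirectional LSTM, so num_directions == 1; batch_first=False is the default
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
x = torch.randn(seq_len, batch, input_size)

output, (h_n, c_n) = lstm(x)
print(output.shape)  # (seq_len, batch, num_directions * hidden_size)
print(h_n.shape)     # (num_layers * num_directions, batch, hidden_size)

# All hidden states of the 2nd sequence (index 1) in the batch, two equivalent ways
second_a = output[:, 1, :]
second_b = output.transpose(0, 1)[1]
assert torch.equal(second_a, second_b)

# For a unidirectional LSTM, the last time step of output is the
# final hidden state of the last layer
assert torch.allclose(output[-1], h_n[-1])
```

Both variants give a tensor of shape `(seq_len, hidden_size)`, i.e., one hidden state per time step for that sequence.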
Please note that you need to be a bit more careful if you define your LSTM with `bidirectional=True`; see a previous post of mine (the image in that post may also help with understanding the output of an LSTM, although it ignores the `batch` dimension).
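As a rough sketch of what "more careful" means in the bidirectional case (this is not taken from the linked post, and the sizes are again arbitrary): the two directions are concatenated along the last dimension of `output`, so you can separate them by reshaping. The forward direction ends at the last time step, while the backward direction ends at the first one.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, chosen only for illustration
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
num_directions = 2

bilstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)
out_bi, (h_bi, c_bi) = bilstm(x)

# out_bi: (seq_len, batch, num_directions * hidden_size)
# Split the last dimension into (direction, hidden); 0 = forward, 1 = backward
out_dirs = out_bi.reshape(seq_len, batch, num_directions, hidden_size)
forward_out = out_dirs[:, :, 0, :]
backward_out = out_dirs[:, :, 1, :]

# h_bi: (num_layers * num_directions, batch, hidden_size)
h_dirs = h_bi.reshape(num_layers, num_directions, batch, hidden_size)

# Last layer: the forward final state sits at the last time step,
# the backward final state at the first time step
assert torch.allclose(forward_out[-1], h_dirs[-1, 0])
assert torch.allclose(backward_out[0], h_dirs[-1, 1])
```

So with `bidirectional=True`, `output[-1]` alone is not the full "final" state: its backward half has only seen the last input.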