In the documentation for torch.nn.LSTM, the hidden state and cell state inputs h_0 and c_0 are described as “containing the initial hidden state for each element in the input sequence.”
I don’t think this is correct. A user would expect h_0 and c_0 to be used only for the initial LSTM states, i.e., for the first element in the input sequence rather than for each element in the input sequence.
This is also how the RNN and GRU documentation describes it: their hidden state h_0 is described as “containing the initial hidden state for the input sequence”.
So is it true that the same values of h_0 and c_0 are used repeatedly as the hidden state and cell state for every single element in the sequence of an LSTM? Or is this simply incorrect information in the LSTM documentation?
Thanks in advance.
I’m not quite sure if I understood your issue correctly. Yes, h_0 and c_0 form the initial hidden state (strictly speaking, the initial hidden and cell state) of the LSTM layer.
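To make “initial state” concrete, here is a minimal sketch, assuming the default `batch_first=False` layout and arbitrary, hypothetical sizes. Note that `h_0` and `c_0` have shape `(num_layers * num_directions, batch, hidden_size)`: one slice per layer and per batch element, but no sequence-length dimension. Omitting them is the same as passing zeros.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)

# Default layout (batch_first=False): (seq_len, batch, input_size)
inputs = torch.randn(seq_len, batch, input_size)

# One initial state per layer and per batch element, not per time step:
# shape (num_layers * num_directions, batch, hidden_size)
h_0 = torch.zeros(num_layers, batch, hidden_size)
c_0 = torch.zeros(num_layers, batch, hidden_size)

out_explicit, _ = lstm(inputs, (h_0, c_0))
out_default, _ = lstm(inputs)  # (h_0, c_0) default to zeros when omitted

print(torch.allclose(out_explicit, out_default))  # True
```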
However, the processing works like this:
- 1st element: use h_0 and element 1 to generate h_1
- 2nd element: use h_1 and element 2 to generate h_2
- 3rd element: use h_2 and element 3 to generate h_3
- n-th element: use h_(n-1) and element n to generate h_n
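The steps above can be checked directly. This is a sketch with hypothetical sizes and a single-layer `nn.LSTM`: it runs the whole sequence in one call, then runs it again one element at a time while carrying `(h, c)` forward, and finally, for contrast, reuses `(h_0, c_0)` at every step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes; single layer, batch_first=False.
seq_len, batch, input_size, hidden_size = 5, 3, 4, 6
lstm = nn.LSTM(input_size, hidden_size)
inputs = torch.randn(seq_len, batch, input_size)

h_0 = torch.randn(1, batch, hidden_size)
c_0 = torch.randn(1, batch, hidden_size)

# Whole sequence in one call.
out_full, (h_n, c_n) = lstm(inputs, (h_0, c_0))

# Same computation, one element at a time: (h_0, c_0) is used only for
# element 1; every later step consumes the state produced by the step before.
h, c = h_0, c_0
step_outputs = []
for t in range(seq_len):
    # inputs[t:t + 1] keeps the time dimension, shape (1, batch, input_size)
    out_t, (h, c) = lstm(inputs[t:t + 1], (h, c))
    step_outputs.append(out_t)
out_steps = torch.cat(step_outputs, dim=0)

print(torch.allclose(out_full, out_steps, atol=1e-6))  # True
print(torch.allclose(h_n, h, atol=1e-6))               # True
print(torch.allclose(c_n, c, atol=1e-6))               # True

# For contrast: feeding (h_0, c_0) to every step gives a different result.
out_reset = torch.cat([lstm(inputs[t:t + 1], (h_0, c_0))[0] for t in range(seq_len)])
print(torch.allclose(out_full, out_reset))  # False
```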
In short, the hidden state is updated after each element: it is not overwritten in place, but a new hidden state is generated. At the end, you have access to all hidden states h_1, ..., h_n. For example, if your code has a line like:
```python
out, (h, c) = lstm(inputs)
```
then h will be h_n, i.e., the last hidden state (or the last hidden state of each layer if you have multiple layers), and out will contain all hidden states, i.e., h_1, ..., h_n, but only for the last layer.
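A quick way to see which hidden states end up where is to print the shapes. This is a sketch with hypothetical sizes and a two-layer, unidirectional `nn.LSTM` using the default `batch_first=False`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration.
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)
inputs = torch.randn(seq_len, batch, input_size)

out, (h, c) = lstm(inputs)

print(out.shape)  # torch.Size([5, 3, 6]): h_1..h_n of the last layer
print(h.shape)    # torch.Size([2, 3, 6]): h_n of each layer
print(c.shape)    # torch.Size([2, 3, 6]): c_n of each layer

# The last time step of `out` is the last layer's final hidden state.
print(torch.allclose(out[-1], h[-1]))  # True
```

For a bidirectional LSTM this last equality no longer holds as written, since the backward direction finishes at the first time step.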
Does this clarify things?
As I suspected, this means that h_0 and c_0 are not used for EACH element in the input sequence. Rather, the hidden and cell states from the previous element in the sequence are used.
I just wanted to confirm that LSTMs in Torch indeed work this way, since the documentation suggests otherwise, unless I am grossly misinterpreting the words “containing the initial hidden state for each element in the input sequence”.