In the documentation for torch.nn.LSTM, the hidden state and cell state inputs h_0 and c_0 are described as “containing the initial hidden state for each element in the input sequence.”
I don’t think this is correct, as a user would expect h_0 and c_0 to be used only as the initial LSTM states, i.e., for the first element in the input sequence rather than for each element in the input sequence.
This is the default behavior for RNNs and GRUs, as you can see from their description of the hidden state h_0: “containing the initial hidden state for the input sequence”.
So is it true that, for an LSTM, the same values of h_0 and c_0 are used repeatedly as the hidden and cell states for every single element in the sequence? Or is this just incorrect information in the LSTM documentation?
I’m not quite sure I understood your issue correctly. Yes, h_0 and c_0 form the initial hidden state (strictly speaking, the hidden and cell states) of the LSTM layer.
However, the processing works like this:
1st element: use h_0 and element 1 to generate h_1
2nd element: use h_1 and element 2 to generate h_2
3rd element: use h_2 and element 3 to generate h_3
…
n-th element: use h_(n-1) and element n to generate h_n
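For reference, here is the per-step update that the list above abbreviates, written in the notation of the torch.nn.LSTM documentation (σ is the sigmoid function and ⊙ the element-wise product), for t = 1, …, n. It makes explicit that h_0 and c_0 enter only at t = 1, as h_{t-1} and c_{t-1}; every later step consumes the states produced by the step before it, and the cell state c_t is carried forward in the same way as h_t:

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```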
In short, the hidden state is updated after each element: it is not overwritten in place, but a new hidden state is generated at every step. At the end, you have access to all hidden states h_1, h_2, …, h_n. For example, if your code has a line like:
out, (h, c) = lstm(inputs)
h will be h_n, i.e., the last hidden state (or the last hidden state of every layer if you have multiple layers), and out will contain all hidden states h_1, h_2, …, h_n, but only for the last layer.
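A quick way to check the shapes described above is the sketch below. The sizes are arbitrary values chosen for illustration, the layout is PyTorch's default (batch_first=False, unidirectional), and since no (h_0, c_0) is passed, PyTorch initializes both to zeros.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, chosen only for illustration
seq_len, batch, input_size, hidden_size, num_layers = 5, 3, 4, 6, 2

lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers)  # batch_first=False by default
inputs = torch.randn(seq_len, batch, input_size)

# No (h_0, c_0) passed, so both initial states default to zeros
out, (h, c) = lstm(inputs)

print(out.shape)  # torch.Size([5, 3, 6]): h_1 ... h_n, last layer only
print(h.shape)    # torch.Size([2, 3, 6]): h_n for each of the 2 layers
print(c.shape)    # torch.Size([2, 3, 6]): c_n for each of the 2 layers

# For a unidirectional LSTM, the last time step of `out` is the last layer's h_n
print(torch.allclose(out[-1], h[-1]))  # True
```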
As I suspected, this means that h_0 and c_0 are not used for EACH element in the input sequence. Rather, the hidden and cell states from the previous element in the sequence are used.
I was just trying to clarify that LSTMs in Torch indeed work this way, as the documentation suggested otherwise, unless I am grossly misinterpreting the meaning of the words “containing the initial hidden state for each element in the input sequence”.
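This conclusion can also be checked empirically. The sketch below (arbitrary sizes, random weights, and a random non-zero (h0, c0) chosen only to make the effect visible) runs the same nn.LSTM once over the whole sequence and once element by element, passing each call's (h, c) into the next call. The two agree, whereas feeding the original (h0, c0) into every step gives different outputs, which is what the "for each element in the input sequence" reading would imply.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes, chosen only for illustration
seq_len, batch, input_size, hidden_size = 7, 2, 3, 5
lstm = nn.LSTM(input_size, hidden_size)
x = torch.randn(seq_len, batch, input_size)

# Non-zero initial states, so it is visible where they are used
h0 = torch.randn(1, batch, hidden_size)
c0 = torch.randn(1, batch, hidden_size)

with torch.no_grad():
    # 1) The whole sequence in a single call
    out_full, (h_full, c_full) = lstm(x, (h0, c0))

    # 2) One element per call, feeding each call's final (h, c) into the next call;
    #    (h0, c0) is used only for the first element
    h, c = h0, c0
    step_outputs = []
    for t in range(seq_len):
        out_t, (h, c) = lstm(x[t:t + 1], (h, c))  # x[t:t+1] keeps the seq dim: (1, batch, input_size)
        step_outputs.append(out_t)
    out_steps = torch.cat(step_outputs, dim=0)

    # 3) For contrast: reuse the same (h0, c0) for every element
    reset_outputs = [lstm(x[t:t + 1], (h0, c0))[0] for t in range(seq_len)]
    out_reset = torch.cat(reset_outputs, dim=0)

print(torch.allclose(out_full, out_steps, atol=1e-5))  # True
print(torch.allclose(h_full, h, atol=1e-5), torch.allclose(c_full, c, atol=1e-5))  # True True
print(torch.allclose(out_full, out_reset, atol=1e-5))  # False
```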