LSTM hidden state logic

Hi,

I am a bit confused about the hidden state in an LSTM. I am reading this tutorial, and in the forward method of the model, self.hidden is used as the input h_0:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)

        # The linear layer that maps from hidden state space to tag space
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
        self.hidden = self.init_hidden()

    def init_hidden(self):
        # Before we've done anything, we don't have any hidden state.
        # Refer to the PyTorch documentation to see exactly
        # why they have this dimensionality.
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        # (autograd.Variable is deprecated; plain tensors work directly.)
        return (torch.zeros(1, 1, self.hidden_dim),
                torch.zeros(1, 1, self.hidden_dim))

    def forward(self, sentence):
        embeds = self.word_embeddings(sentence)
        lstm_out, self.hidden = self.lstm(
            embeds.view(len(sentence), 1, -1), self.hidden)  # <- here self.hidden is used as h_0
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        # log_softmax over the tag dimension (one row of scores per word)
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores
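To see what self.hidden actually holds, here is a small sketch (the sizes are arbitrary, chosen only for illustration): nn.LSTM loops over all timesteps internally and returns (h_n, c_n), the states after the last timestep. For a single-layer, unidirectional LSTM, h_n matches the last row of the output, and it is this pair that the tutorial code stores and feeds back as h_0 on the next call.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary sizes for illustration: 5 timesteps, batch of 1, 3 input features, 4 hidden units.
seq_len, batch, input_dim, hidden_dim = 5, 1, 3, 4
lstm = nn.LSTM(input_dim, hidden_dim)
x = torch.randn(seq_len, batch, input_dim)

# Start from zero states, shaped (num_layers, batch, hidden_dim).
h0 = torch.zeros(1, batch, hidden_dim)
c0 = torch.zeros(1, batch, hidden_dim)
out, (h_n, c_n) = lstm(x, (h0, c0))

print(out.shape)  # one output per timestep: (5, 1, 4)
print(h_n.shape)  # only the state after the last timestep: (1, 1, 4)
# The final hidden state equals the output at the last timestep.
print(torch.allclose(out[-1], h_n[0]))
```

So the timestep-to-timestep recurrence happens inside a single call; what self.hidden carries between calls is only the final (h, c) pair.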

Does it mean we are retaining the hidden state across batches (not just across timesteps)? Why would one want to do that?
If I want to initialize the hidden state (e.g. randomly) but not retain it across batches, how should I do it? Thanks.


Does it mean we are retaining the hidden state across batches (not just across timesteps)? Why would one want to do that?

Yes, exactly. Think of the hidden states as the network's memory, not as its learned weights (the weights are the LSTM's parameters, which the optimizer updates). If you reset the hidden states after each batch, the network forgets everything it saw in earlier batches. Together with the current input, the hidden state determines the gates (input, forget, output) of the LSTM, and it carries information about what the network has seen so far. Therefore, the output depends not only on the most recent input, but also on data the network has seen in the past. This is the whole idea of the LSTM: it mitigates the long-term dependency problem. Read this excellent blog post for further information:
https://colah.github.io/posts/2015-08-Understanding-LSTMs/
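For reference, the standard LSTM update, written here in the notation of the PyTorch nn.LSTM documentation (σ is the sigmoid, ⊙ the element-wise product), makes the split explicit: the W and b terms are the learned parameters, while h_{t-1} and c_{t-1} are the state carried over from the previous timestep, which together with the input x_t determine the gates.

```latex
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

Resetting h and c to zeros leaves every W and b untouched; it only clears the memory a sequence starts from.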

If I want to initialize the hidden state (e.g. randomly) but not retain it across batches, how should I do it?

Well, if you really want to, don't store the hidden states as an instance attribute (self.hidden); instead, initialize them for every batch. This could be done, e.g., by passing them to your forward function as an additional argument and initializing them at the start of every iteration of your training loop. Again, I don't recommend doing that, because the LSTM would then no longer carry any information from one batch to the next.
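A minimal sketch of that approach, assuming a hypothetical StatelessTagger (a trimmed-down variant of the tutorial model) and made-up random training data: the hidden state becomes an argument of forward instead of an attribute, and the training loop creates a fresh pair for every batch (random here via torch.randn, since the question asked about random initialization), so nothing is carried over.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class StatelessTagger(nn.Module):
    """Hypothetical variant of LSTMTagger that does not store the hidden state."""

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.hidden2tag = nn.Linear(hidden_dim, tagset_size)

    def init_hidden(self, batch_size=1):
        # Random initial states, shaped (num_layers, batch_size, hidden_dim).
        return (torch.randn(1, batch_size, self.hidden_dim),
                torch.randn(1, batch_size, self.hidden_dim))

    def forward(self, sentence, hidden):
        embeds = self.word_embeddings(sentence)
        # The LSTM's final state is returned but deliberately discarded.
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1), hidden)
        tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
        return F.log_softmax(tag_space, dim=1)


# Made-up toy data: sentences of word indices with one tag index per word.
vocab_size, tagset_size = 10, 3
data = [(torch.randint(0, vocab_size, (6,)), torch.randint(0, tagset_size, (6,)))
        for _ in range(4)]

model = StatelessTagger(embedding_dim=8, hidden_dim=16,
                        vocab_size=vocab_size, tagset_size=tagset_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.NLLLoss()

for sentence, tags in data:
    optimizer.zero_grad()
    hidden = model.init_hidden()  # fresh state for every batch
    tag_scores = model(sentence, hidden)
    loss = loss_fn(tag_scores, tags)
    loss.backward()
    optimizer.step()
```

Swapping torch.randn for torch.zeros in init_hidden gives the usual zero-initialized stateless setup.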


Thank you @engelnico. I think I am still confused about why we would want to retain the hidden state across batches, mainly because I read this https://discuss.pytorch.org/t/when-to-call-init-hidden-for-rnn/11518/7:

I think you should call

hidden = net.init_hidden(batch_size)

for every batch, because the hidden state after a batch pass contains information about the whole previous batch. At test time you'd only have a new hidden state for every sentence, so you probably want to train for that.

The argument in this comment is that, since we will use an initial hidden state (say zeros) at test time, we want the model to learn under the same conditions by resetting the hidden state for each sample during training (so the hidden state is passed across timesteps, but not across samples). Am I missing anything here? Thanks.

I think I misunderstood your first question, I'm sorry…
There are usually two different modes for an LSTM: stateless and stateful.

Stateless mode: the parameters are updated on one batch, and when the next batch comes, the states are initialized again (with zeros). Therefore, the “cell memory” is reset after every batch.

Stateful mode is the one I described above: the final states of one batch are used as the initial states of the next batch, so the network remembers what it has seen in earlier batches.
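A minimal sketch of stateful mode, assuming one long series that has been cut into consecutive chunks (toy random data here) and a hypothetical linear readout head: the final (h, c) of one chunk becomes the initial state of the next. The detach() call keeps the state's values but cuts the autograd graph at the chunk boundary (truncated backpropagation through time); without it, the second backward() would try to go back into the previous chunk's graph, which the first backward() has already freed, and raise an error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_dim, hidden_dim, chunk_len, n_chunks = 3, 8, 10, 5
lstm = nn.LSTM(input_dim, hidden_dim)
head = nn.Linear(hidden_dim, 1)  # hypothetical readout layer
params = list(lstm.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

# Toy time series, split into consecutive chunks of shape (chunk_len, batch=1, input_dim).
series = torch.randn(n_chunks * chunk_len, 1, input_dim)
targets = torch.randn(n_chunks * chunk_len, 1, 1)
chunks = zip(series.split(chunk_len), targets.split(chunk_len))

# Start from zeros once, then carry the state from chunk to chunk.
hidden = (torch.zeros(1, 1, hidden_dim), torch.zeros(1, 1, hidden_dim))
for x, y in chunks:
    # Keep the values but cut the graph at the chunk boundary.
    hidden = tuple(h.detach() for h in hidden)
    optimizer.zero_grad()
    out, hidden = lstm(x, hidden)  # final state of this chunk -> initial state of the next
    loss = nn.functional.mse_loss(head(out), y)
    loss.backward()
    optimizer.step()
```

This only makes sense if chunk k+1 really continues chunk k, so the chunks must be fed in order and not shuffled.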

Which mode is right for you is up to you and, most importantly, depends on the type of data you use. The question you have to ask is: is your data time-dependent across different batches? If yes -> stateful; if not -> stateless. From your initial post, I guess it is not time-dependent across batches, so stateless mode is probably the right one for you.

Again, sorry for the confusion.


Thank you @engelnico. Do you also know the behavior if I do not specify the hidden states (they are set to zero by default)? Is that stateful or stateless?


Yes, if (h_0, c_0) is not provided as input, both h_0 and c_0 default to zero (source: documentation). Since every call then starts from zero states, that corresponds to stateless mode.
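A quick sketch to check this (sizes are arbitrary): calling the LSTM without a hidden state gives the same result as passing explicit zeros, so each call starts from scratch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

lstm = nn.LSTM(input_size=3, hidden_size=4, num_layers=2)
x = torch.randn(7, 2, 3)  # (seq_len, batch, input_size)

# No hidden state passed: h_0 and c_0 default to zeros.
out_default, (h_default, c_default) = lstm(x)

# Explicit zeros, shaped (num_layers, batch, hidden_size).
zeros = torch.zeros(2, 2, 4)
out_zeros, (h_zeros, c_zeros) = lstm(x, (zeros, zeros))

print(torch.allclose(out_default, out_zeros))
print(torch.allclose(h_default, h_zeros))
```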