Initialization of first hidden state in LSTM and truncated BPTT

Hi all,

I am trying to implement my first LSTM with PyTorch, and hence I am following some tutorials.
In particular I am following:
https://www.deeplearningwizard.com/deep_learning/practical_pytorch/pytorch_lstm_neuralnetwork/
which looks like this:

```python
import torch
import torch.nn as nn


class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTMModel, self).__init__()
        # Hidden dimensions
        self.hidden_dim = hidden_dim

        # Number of hidden layers
        self.layer_dim = layer_dim

        # Building your LSTM
        # batch_first=True causes input/output tensors to be of shape
        # (batch_dim, seq_dim, feature_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)

        # Readout layer
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize hidden state with zeros
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        # Initialize cell state
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_()

        # Run the LSTM over the whole sequence
        # We need to detach as we are doing truncated backpropagation through time (BPTT)
        # If we don't, we'll backprop all the way to the start even after going through another batch
        out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))

        # Index hidden state of last time step
        # out.size() --> 100, 28, 100
        # out[:, -1, :] --> 100, 100 --> just want last time step hidden states!
        out = self.fc(out[:, -1, :])
        # out.size() --> 100, 10
        return out
```

So my questions are:
1.) Is it common practice to initialize the first hidden state to be a vector of zeros? What alternatives exist and are they better?

2.) Is it normal to re-initialize the first hidden state every time the forward method is called?

3.) In that example they detach the first hidden state from the computational graph and claim that if they didn't, backpropagation would go all the way back through older batches. But is this really true? I have seen no other tutorial on LSTMs or RNNs where they detach the first hidden state. I would have guessed that a new computational graph is constructed each time the forward method is called?

  1. Yes, a zero initial hidden state is standard, so much so that it is the default in nn.LSTM if you don’t pass in a hidden state (rather than, e.g., throwing an error). Random initialization could also be used if zeros don’t work. A few basic ideas here:

    • If your hidden state evolution is “ergodic”, the state will move closer to some “steady distribution” anyways, so it doesn’t matter as much.
    • You want the initial hidden state handling to be somewhat consistent between training and inference.
    • The fancy Bayesian way would be to sample from said steady state, but deep learning is too wild to resort to fancy when it isn’t necessary.
  2. For truncated BPTT (aka “feeding in a long sequence bit by bit”), you could keep the last hidden state and use that (detached) as the new initial hidden state if you think the state should be carried between batches. Language model training on Wikipedia (as a common example) will do things like that; there is a small sketch of this pattern after this list.

  3. In theory this is really true. In practice you would run out of memory instead. I can’t speak about other tutorials, but those that I have seen do the detaching (or don’t keep state from previous batches) and it would seem necessary to do so.
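
To make item 2 concrete, here is a minimal sketch of a training loop that carries the (detached) state across batches. The sizes, the random placeholder data and the loss are made up for illustration; the assumption is that consecutive batches contain consecutive chunks of the same sequences:

```python
import torch
import torch.nn as nn

# Toy sizes, just for illustration
lstm = nn.LSTM(input_size=28, hidden_size=100, num_layers=1, batch_first=True)
fc = nn.Linear(100, 10)
optimizer = torch.optim.SGD(list(lstm.parameters()) + list(fc.parameters()), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Random placeholder data standing in for consecutive chunks of the same sequences
batches = [(torch.randn(100, 28, 28), torch.randint(0, 10, (100,))) for _ in range(3)]

state = None  # nn.LSTM falls back to zeros when no state is passed
for x, y in batches:
    optimizer.zero_grad()
    out, state = lstm(x, state)
    # Detach so the next backward() stops here instead of reaching into the previous batch's graph
    state = (state[0].detach(), state[1].detach())
    loss = criterion(fc(out[:, -1, :]), y)
    loss.backward()
    optimizer.step()
```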

Best regards

Thomas

P.S.: For the forum: lines with triple backticks ```python at the beginning of your code and ``` at the end will make it look nice.


@tom

Thank you very much for your answer. It is very much appreciated.

I have one more question about 3.), the detaching:
In the example above, the odd thing is that they detach a first hidden state that they have newly created, and that they create anew every time they call forward. So in that example the detaching is actually redundant, no? (Since a newly created tensor has no history.)
So detaching would actually only be required in a situation like the one you described in 2.), where the last hidden state is reused as the new initial state, right?

Oh yes, sorry. requires_grad_() + detach() is redundant here, and just having h0 = torch.zeros(...) and passing that in would be more idiomatic (in general, it seems strange to use requires_grad_ right after a factory function where you could use requires_grad=True, but hey).
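
For reference, a minimal sketch of the forward method rewritten along those lines (dropping requires_grad_ and detach; the device= argument is my addition so the zeros end up on the same device as the input), meant as a drop-in for the forward in the first post:

```python
def forward(self, x):
    # Fresh zero states on every call: no history, so nothing to detach
    h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
    c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
    out, (hn, cn) = self.lstm(x, (h0, c0))
    return self.fc(out[:, -1, :])
```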

Best regards

@tom great, thanks a lot!

@tom
Your answer was really helpful! I just had one question about point 3. You say ‘but those that I have seen do the detaching (or don’t keep state from previous batches) and it would seem necessary to do so’ - in the example shown by @Raphikowski they don’t keep state from previous batches, as far as I understand (since they initialize h0, c0 to zeros in forward) - so why is there really a need to detach?

This was worded rather unfortunately.

Your options are more or less

  • keep the state from previous batches but detach it,
  • set the state to zero at each batch (this obviously will not be attached to the previous batch),

but you cannot

  • work with (undetached) state that is still connected to the last step: this will either lead to “trying to backward through the graph a second time” errors or, if you instruct PyTorch to keep the graph, memory exhaustion (and probably gradient blow-up, too).
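
To make the last point concrete, here is a toy reproduction (tiny made-up sizes, nothing to do with the tutorial) of the failing pattern next to the working one:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=3, hidden_size=5, batch_first=True)
x = torch.randn(2, 4, 3)

# Carrying the state over without detaching: the second backward() needs the
# first graph again, whose buffers were already freed by the first backward().
out, state = lstm(x)
out.sum().backward()
out2, _ = lstm(x, state)
# out2.sum().backward()  # RuntimeError: Trying to backward through the graph a second time ...

# Carrying the detached state over works: the second graph stops at the detached state.
out, state = lstm(x)
out.sum().backward()
state = (state[0].detach(), state[1].detach())
out2, _ = lstm(x, state)
out2.sum().backward()  # fine
```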

Best regards

Thomas


Huge thanks, @tom, for the great explanation!

I got a question though. Why should the hidden state require grad? In the example they pass it in detached, so it is not going to be part of the graph anyway. The hidden state is not supposed to be a parameter to compute derivatives against, right? I’d rather initialize it with requires_grad=False, no?

Indeed, if we read this in isolation, it doesn’t make much sense.
If you don’t want to train the initial hidden state, you don’t need the gradient; if you want to train the initial hidden state as part of your model, it would make more sense to just have it as a parameter (a small sketch of that is below).
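
For completeness, a minimal sketch of that second option, i.e. a learned initial state stored as a parameter; the class name and layout are made up for illustration and not from the tutorial:

```python
import torch
import torch.nn as nn

class LSTMWithLearnedInit(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        # Learnable initial states (one per layer), expanded to the batch size in forward
        self.h0 = nn.Parameter(torch.zeros(layer_dim, 1, hidden_dim))
        self.c0 = nn.Parameter(torch.zeros(layer_dim, 1, hidden_dim))

    def forward(self, x):
        batch = x.size(0)
        h0 = self.h0.expand(-1, batch, -1).contiguous()
        c0 = self.c0.expand(-1, batch, -1).contiguous()
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out[:, -1, :])
```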

Best regards

Thomas
