How to handle last batch in LSTM hidden state


(Anjandeep Sahni) #1

I am trying to set up a simple RNN using an LSTM, and I am running into an issue when passing the hidden state of the RNN from one batch to the next. For example, say my batch_size = 64 and I am using batch_first = True, hidden_size = 100 and num_layers = 3.

def __init__(self):
    self.rnn = nn.LSTM(input_size=40, hidden_size=100, num_layers=3, batch_first=True)
    self.hidden = None
...
...
def forward(self, in_pkd):
    out_pkd, (h, c) = self.rnn(in_pkd, self.hidden)
    # detach so gradients do not flow back across batch boundaries
    self.hidden = (h.detach(), c.detach())

The shape of my hidden state is then [3, 64, 100]. Now, if my last batch is smaller than the batch_size of 64, say 55, I get the following error, because the hidden state carried over from the second-to-last batch still has size [3, 64, 100]:

RuntimeError: Expected hidden[0] size (3, 55, 100), got (3, 64, 100)

The above error occurs at this line:

out_pkd, (h, c) = self.rnn(in_pkd, self.hidden)

How do I handle the smaller batch size of the last batch in this setup?


(Chris) #2

You need to reset/initialize your hidden state before each forward pass; at that point you can adjust it to your current batch size. The most common pattern is to have a method in your model class like this:

def init_hidden(self, batch_size):
    # Return a zero-initialized hidden state that matches the current batch size.
    if self.rnn_type == 'gru':
        return torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device)
    elif self.rnn_type == 'lstm':
        return (torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device),
                torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device))
    else:
        raise ValueError('Unknown rnn_type. Valid options: "gru", "lstm"')

Note that this version is already flexible enough to support different RNN types, different numbers of layers, and bidirectional layers. The method can be much simpler if you don’t need all of that.
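For example, hardcoded to the unidirectional, 3-layer LSTM from your question, it could look roughly like this (just a sketch, assuming the same self.device attribute as above):

def init_hidden(self, batch_size):
    # LSTM only: num_layers=3, unidirectional, hidden_size=100
    h0 = torch.zeros(3, batch_size, 100, device=self.device)
    c0 = torch.zeros(3, batch_size, 100, device=self.device)
    return (h0, c0)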

Assuming you call forward() with a parameter inputs that contains the current batch with shape (batch_size, ...), you can call self.init_hidden(inputs.size(0)) as the first statement in your forward() method – or before the forward() call in the loop that iterates over the batches.
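Put together, your forward() could look roughly like this (only a sketch, keeping the in_pkd/out_pkd names from your snippet):

def forward(self, in_pkd):
    # in_pkd has shape (batch_size, seq_len, 40) because batch_first=True
    hidden = self.init_hidden(in_pkd.size(0))  # fresh state sized to this batch
    out_pkd, hidden = self.rnn(in_pkd, hidden)
    return out_pkd

Since the state is rebuilt for every batch, the last, smaller batch no longer causes a size mismatch.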

That should fix your problem.


(Anjandeep Sahni) #3

Thanks, this helps. I actually also want to learn the initial hidden state passed to the LSTM by wrapping self.hidden in nn.Parameter. In that case, I don’t think it would be right to re-initialize it with zeros each time, right?

If I then initialized self.hidden as

self.hidden = nn.Parameter(torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim))

and then did backprop for the current batch, how do I feed this same hidden state to the LSTM during the next batch?


(Chris) #4

You can re-initialize the hidden state with zeros each time; it’s common practice. For learning a (maybe) better initial hidden state, you can have a look at this earlier post.
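One common way to make the initial state learnable while still handling a variable batch size (not necessarily what the linked post does, just a sketch using the sizes from your question) is to store the parameters with a batch dimension of 1 and expand them to the current batch size in forward():

def __init__(self):
    ...
    # learnable initial state, stored with batch dimension 1:
    # shape (num_layers, 1, hidden_size)
    self.h0 = nn.Parameter(torch.zeros(3, 1, 100))
    self.c0 = nn.Parameter(torch.zeros(3, 1, 100))

def forward(self, in_pkd):
    batch_size = in_pkd.size(0)
    # expand along the batch dimension; gradients still flow back into h0/c0
    h0 = self.h0.expand(-1, batch_size, -1).contiguous()
    c0 = self.c0.expand(-1, batch_size, -1).contiguous()
    out_pkd, _ = self.rnn(in_pkd, (h0, c0))
    return out_pkd

Because the batch-sized state is recreated from the parameters in every forward pass, the smaller last batch is handled automatically, and the same learned initial state is used for every batch.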


(Anjandeep Sahni) #5

Thanks a lot. That post was quite helpful.