How to handle last batch in LSTM hidden state

I am trying to set up a simple RNN using an LSTM, and I am running into an issue with passing the hidden state of the RNN from one batch to the next. For example, my batch_size = 64, and I am using batch_first=True, hidden_size = 100 and num_layers = 3.

def __init__(self):
    self.rnn = nn.LSTM(input_size=40, hidden_size=100, num_layers=3, batch_first=True)
    self.hidden = None
    ...

def forward(self, in_pkd):
    out_pkd, (h, c) = self.rnn(in_pkd, self.hidden)
    self.hidden = (h.detach(), c.detach())

The shape of my hidden state is then [3, 64, 100]. Now, if my last batch is smaller than the batch_size of 64, say 55, I get the following error, because the hidden state carried over from the second-to-last batch still has size [3, 64, 100].

RuntimeError: Expected hidden[0] size (3, 55, 100), got (3, 64, 100)

The above error occurs at this line:

out_pkd, (h, c) = self.rnn(in_pkd, self.hidden)

How do I handle the case where the last batch is smaller than batch_size?


You need to reset/initialize your hidden state before each forward pass; at that point you can adjust it to your current batch size. The most common pattern is to have a method in your model class like:

def init_hidden(self, batch_size):
    if self.rnn_type == 'gru':
        return torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device)
    elif self.rnn_type == 'lstm':
        return (torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device),
                torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device))
    else:
        raise Exception('Unknown rnn_type. Valid options: "gru", "lstm"')

Note that this one is already flexible enough to support different RNN types, different numbers of layers, and bidirectional layers (via self.directions_count). The method can be much simpler if you don’t need all of this.
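If you only need a plain unidirectional LSTM, a stripped-down version could look like the sketch below (my own illustration; self.num_layers, self.rnn_hidden_dim and self.device are assumed to be set in __init__, mirroring the names above):

def init_hidden(self, batch_size):
    # Zero-valued (h_0, c_0), each of shape (num_layers, batch_size, hidden_dim)
    h_0 = torch.zeros(self.num_layers, batch_size, self.rnn_hidden_dim, device=self.device)
    c_0 = torch.zeros(self.num_layers, batch_size, self.rnn_hidden_dim, device=self.device)
    return (h_0, c_0)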

Assuming that you call forward() with a parameter inputs that contains the current batch with shape (batch_size, ...), you can call self.hidden = self.init_hidden(inputs.shape[0]) as the first statement of your forward() method – or before the forward() call in the loop that iterates over the batches.
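For instance, inside forward() (assuming the init_hidden() from above and the LSTM set up as in the question; I'm using inputs instead of in_pkd here):

def forward(self, inputs):
    # inputs: (batch_size, seq_len, input_size), since batch_first=True
    # Re-create the hidden state for the *current* batch size, so the
    # smaller last batch no longer triggers the size-mismatch error
    self.hidden = self.init_hidden(inputs.shape[0])
    out, (h, c) = self.rnn(inputs, self.hidden)
    return out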

That should fix your problem.


Thanks, this helps. I actually also want to learn the initial hidden state sent to the LSTM, by wrapping self.hidden in nn.Parameter. In that case, I don’t think it would be right to re-init with zeros each time, right?

If I then initialized self.hidden to

self.hidden = nn.Parameter(torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim))

and then did backprop for the current batch, how do I feed this same hidden state to the LSTM during the next batch?

You can re-init the hidden state with 0s each time; it’s a common practice. For learning a (maybe) better initial hidden state, you can have a look at this earlier post.
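The gist of that approach, as a rough sketch of my own (not the exact code from the linked post): register h_0/c_0 as nn.Parameters with a batch dimension of 1 and expand them to the current batch size in forward(), so a smaller last batch is handled as well and gradients still flow into the initial state.

import torch
import torch.nn as nn

class ModelWithLearnedInit(nn.Module):
    def __init__(self, input_size=40, hidden_size=100, num_layers=3):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Learnable initial states with a singleton batch dimension
        self.h_0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))
        self.c_0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))

    def forward(self, inputs):
        batch_size = inputs.shape[0]
        # Expand to the current batch size; .contiguous() makes a real copy
        # that the LSTM accepts, and gradients still reach h_0 / c_0
        hidden = (self.h_0.expand(-1, batch_size, -1).contiguous(),
                  self.c_0.expand(-1, batch_size, -1).contiguous())
        out, _ = self.rnn(inputs, hidden)
        return out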


Thanks a lot. That post was quite helpful.

@sahni
I tried to use your code for passing hidden states across batches.
I see that I will face the last-batch issue as well, but let’s keep that aside for now.

I tried to debug using PyCharm, and I observed that self.hidden was getting reinitialized to None every time.
Batch 1: starts with None, then changes according to (h, c)
Batch 2: starts with None again? How is this possible?

Hi @vdw, I still don’t get it – why is it right to pass 0s as the hidden state each time? Thanks!


There’s nothing inherently special about 0s; it’s merely a way to represent “no prior knowledge”.

For example, when you want to classify individual sentences, your batches are completely independent – that is, the outcome for the sentences in Batch 2 should not depend on the outcome of Batch 1. Of course, if your sentences do depend on each other, then you don’t want to re-initialize for each batch.

You don’t have to re-initialize to 0s anyway. You can always try without (you probably still have to detach the hidden state after each batch) and see how it affects the results. You can also re-initialize with random values for each batch.
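To make the difference concrete, here is a small self-contained sketch of both options (dummy data, no training step shown):

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=40, hidden_size=100, num_layers=3, batch_first=True)
hidden = None  # None is treated as an all-zero initial state

# Dummy batches just to illustrate the state handling
batches = [torch.randn(64, 10, 40) for _ in range(3)]

for inputs in batches:
    if hidden is not None:
        # Option A: carry the state across batches, but detach it so
        # backprop does not reach back into previous batches
        hidden = (hidden[0].detach(), hidden[1].detach())
    # Option B (independent batches) would instead be:
    # hidden = None  # or zeros / random values of shape (3, inputs.shape[0], 100)
    out, hidden = rnn(inputs, hidden)

# Note: carrying the state across batches assumes a constant batch size
# (e.g. drop_last=True in the DataLoader); otherwise the original
# size-mismatch error from this thread reappears.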
