How to handle last batch in LSTM hidden state

(Anjandeep Sahni) #1

I am trying to set up a simple RNN using LSTM. I am facing an issue with passing the hidden state of the RNN from one batch to the next. For example, say my batch_size = 64, and I am using batch_first = True, hidden_size = 100 and num_layers = 3.

    self.rnn = nn.LSTM(input_size=40, hidden_size=100, num_layers=3, batch_first=True)
    self.hidden = None
    out_pkd, (h, c) = self.rnn(in_pkd, self.hidden)
    self.hidden = (h.detach(), c.detach())

Then the shape of my hidden state is [3, 64, 100]. Now if my last batch is smaller than the batch_size 64, say 55, then I get the following error because the hidden state from the second last batch was of size [3, 64, 100].

RuntimeError: Expected hidden[0] size (3, 55, 100), got (3, 64, 100)

Above error occurs at line:

out_pkd, (h, c) = self.rnn(in_pkd, self.hidden)

How do I handle the case of smaller batch size for last batch in this case?

(Chris) #2

You need to reset/initialize your hidden state before each forward pass; at that point you can size it for your current batch. The most common pattern is to have a method in your model class like:

def init_hidden(self, batch_size):
    if self.rnn_type == 'gru':
        return torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device)
    elif self.rnn_type == 'lstm':
        return (torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device),
                torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim).to(self.device))
    else:
        raise ValueError('Unknown rnn_type. Valid options: "gru", "lstm"')

Note that this version is already flexible enough to support different types of RNNs, different numbers of layers, and bidirectional layers. The method can be much simpler if you don’t need that flexibility.

Assuming you call forward() with a parameter inputs that contains the current batch with shape (batch_size, ...), you can call self.init_hidden(inputs.size(0)) as the first statement in your forward() method – or before calling forward() in the loop that iterates over the batches.

That should fix your problem.
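Putting the pieces together, here is a minimal sketch of the pattern (class and variable names are just illustrative): a fresh zero state is created per forward pass, sized for whatever batch arrives, so the smaller final batch no longer triggers the size mismatch.

```python
import torch
import torch.nn as nn

class SimpleLSTM(nn.Module):
    def __init__(self, input_size=40, hidden_size=100, num_layers=3):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                           num_layers=num_layers, batch_first=True)

    def init_hidden(self, batch_size):
        # Fresh zero state, sized for the current batch.
        h = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        c = torch.zeros(self.num_layers, batch_size, self.hidden_size)
        return (h, c)

    def forward(self, inputs):
        # inputs: (batch_size, seq_len, input_size) because batch_first=True
        hidden = self.init_hidden(inputs.size(0))
        out, _ = self.rnn(inputs, hidden)
        return out

model = SimpleLSTM()
out_full = model(torch.randn(64, 10, 40))  # full batch
out_last = model(torch.randn(55, 10, 40))  # smaller final batch, no error
```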

(Anjandeep Sahni) #3

Thanks, this helps. I actually also want to learn the initial hidden state that is passed to the LSTM, by wrapping self.hidden in nn.Parameter. In that case, I don’t think it would be right to re-initialize it with zeros each time, right?

If I then initialized self.hidden to

self.hidden = nn.Parameter(torch.zeros(self.num_layers * self.directions_count, batch_size, self.rnn_hidden_dim))

and then did backprop for current batch, how do I feed this same hidden state to the LSTM during next batch?

(Chris) #4

You can re-initialize the hidden state with zeros each time; it’s common practice. For learning a (possibly) better initial hidden state, have a look at this earlier post.
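One way to sketch a learnable initial state (names here are illustrative, not from the earlier post): store the parameters without a batch dimension and expand them to the current batch size in forward(). Gradients flow back through expand() to the shared parameters, and the last, smaller batch works automatically.

```python
import torch
import torch.nn as nn

class LearnedInitLSTM(nn.Module):
    def __init__(self, input_size=40, hidden_size=100, num_layers=3):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # Learned initial state, stored with a batch dimension of 1
        # so it is independent of the batch size seen at training time.
        self.h0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))
        self.c0 = nn.Parameter(torch.zeros(num_layers, 1, hidden_size))

    def forward(self, inputs):
        batch_size = inputs.size(0)
        # Broadcast the learned state to the current batch size;
        # expand() creates a view, so gradients still reach h0/c0.
        h0 = self.h0.expand(-1, batch_size, -1).contiguous()
        c0 = self.c0.expand(-1, batch_size, -1).contiguous()
        out, _ = self.rnn(inputs, (h0, c0))
        return out
```

Because the parameters are expanded per batch rather than stored at a fixed batch size, there is no state to carry over between batches; the learned values update via backprop like any other parameter.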

(Anjandeep Sahni) #5

Thanks a lot. That post was quite helpful.