Handling the hidden state with minibatches in an RNN for language modelling

Hi,
I don’t understand how to handle the hidden state when passing minibatches of sentences into my RNN.

In my case the input data to the model is a minibatch of N sentences of varying length. Each sentence consists of word indices, where each index represents a word in the vocabulary:

sents = [[4, 545, 23, 1], [34, 84], [23, 6, 774]]

The sentences in the dataset are randomly shuffled before creating minibatches.
Here is how the minibatches are created:

def batches(data, batch_size):
    """ Yields batches of sentences from 'data', each batch sorted by
    decreasing sentence length (as pack_sequence expects). """
    random.shuffle(data)
    for i in range(0, len(data), batch_size):
        sentences = data[i:i + batch_size]
        sentences.sort(key=lambda l: len(l), reverse=True)
        yield [torch.LongTensor(s) for s in sentences]

The model predicts the next element in the sentence, so the input and target look like this:

input_sentence = [1, 4, 5, 7]
target_sentence = [4, 5, 7, 9]

Packed sequences are used to handle sentences of varying length. Here is how the input and target are created:

x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])
y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])
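
To make the packing concrete, here is what pack_sequence gives for the example batch above, once the sentences are sorted by decreasing length (as batches() guarantees); the printed values are just to illustrate the layout:

import torch
import torch.nn as nn

# The example batch from above, already sorted by decreasing length.
sents = [torch.LongTensor(s) for s in ([4, 545, 23, 1], [23, 6, 774], [34, 84])]

x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])
y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])

# A PackedSequence stores the tokens time step by time step across sentences:
# time 0 holds the first token of every sentence, time 1 the second, and so on.
print(x.data)         # tensor([  4,  23,  34, 545,   6,  23])
print(x.batch_sizes)  # tensor([3, 2, 1])
print(y.data)         # tensor([545,   6,  84,  23, 774,   1])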

This input x, consisting of a minibatch of sentences, is then sent through the forward pass of the model:

out = model(x)

The model itself:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """ A language model RNN with GRU layer(s). """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, gru_layers, dropout):
        super(Model, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.recurrent_layer = nn.GRU(input_size=embedding_dim, hidden_size=hidden_dim,
                                      num_layers=gru_layers, dropout=dropout,
                                      bidirectional=False)
        self.fc1 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, packed_sents):
        """ Takes a PackedSequence of sentences tokens that has T tokens
        belonging to vocabulary V. Outputs predicted log-probabilities
        for the token following the one that's input in a tensor shaped
        (T, |V|).
        """
        embedded_sents = nn.utils.rnn.PackedSequence(
            self.embedding(packed_sents.data), packed_sents.batch_sizes)
        out_packed_sequence, hidden = self.recurrent_layer(embedded_sents)
        out = self.fc1(out_packed_sequence.data)
        return F.log_softmax(out, dim=1)
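
For reference, a minimal training step with this setup could look something like the sketch below; the optimizer, learning rate and model hyperparameters are just illustrative placeholders, and 'data' stands for the list of index-encoded sentences described above:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder hyperparameters, only for illustration.
model = Model(vocab_size=10000, embedding_dim=128, hidden_dim=256,
              gru_layers=1, dropout=0.0)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for sents in batches(data, batch_size=32):
    x = nn.utils.rnn.pack_sequence([s[:-1] for s in sents])
    y = nn.utils.rnn.pack_sequence([s[1:] for s in sents])

    optimizer.zero_grad()
    out = model(x)                  # (T, |V|) log-probabilities
    loss = F.nll_loss(out, y.data)  # y.data is the flat tensor of T target tokens
    loss.backward()
    optimizer.step()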

In this example, the hidden state is not explicitly passed from one forward pass to the next.

Question 1)
I don’t understand how to pass the hidden state from one time step to the next. I want the hidden state from item 1 in the sentence to be passed into the computation of item 2.
So for the sentence [1, 2, 3], how is the hidden state after processing element 1 passed to the computation of element 2? Is this handled under the hood by PyTorch when the minibatch of sentences is fed into the recurrent layer? Like here:

 self.recurrent_layer(embedded_sents)

I have seen some examples that explicitly feed the previous hidden state into the recurrent layer in the current forward pass. For example:

out_packed_sequence, self.hidden = self.recurrent_layer(embedded_sents, self.hidden)

But I don’t think that would be correct in this situation because the input is a minibatch of sentences - and then I would be using the hidden state from the previous minibatch for the current minibatch. I only want to pass the hidden state from the current item to the next item within the sentence.

Question 2)
I think the hidden state should be reset after each sentence, because the dependencies I want to capture are between the items within each sentence, not between different sentences. Is there a way to reset the hidden state after each sentence when minibatches are used, given that the entire minibatch is fed into the model in each forward pass? Or do I need to send one sentence at a time into the model per forward pass and reset the hidden state after each sentence, instead of using minibatches?

The code is adapted from this github repo:
https://github.com/florijanstamenkovic/PytorchRnnLM

Any help is appreciated.

Regarding (1): Yes, everything is handled under the hood by PyTorch. If embedded_sents in self.recurrent_layer(embedded_sents, self.hidden) is a batch of complete sentences, the hidden state is updated and re-used internally, time step by time step. If for some reason you want to intervene at each time step, you can give self.recurrent_layer sequences of length 1.
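
To make that concrete, here is a small toy sketch (plain tensors instead of packed sequences, just for illustration) showing that feeding a whole sequence in one call gives the same result as feeding it one time step at a time while passing the hidden state yourself:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16, num_layers=1)
seq = torch.randn(5, 3, 8)  # (seq_len=5, batch=3, input_size=8)

# One call: PyTorch carries the hidden state across time steps internally.
out_full, h_full = gru(seq)

# Step by step: pass the hidden state explicitly from one step to the next.
h = None
outs = []
for t in range(seq.size(0)):
    out_t, h = gru(seq[t:t + 1], h)
    outs.append(out_t)
out_steps = torch.cat(outs, dim=0)

print(torch.allclose(out_full, out_steps, atol=1e-6))  # True
print(torch.allclose(h_full, h, atol=1e-6))            # True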

Regarding (2): Yes, if your batches are independent, you generally reset self.hidden before calling forward() on the next batch. You can have a look for that in some code of mine for an RNN-based autoencoder. Check the method init_hidden(self, batch_size) and where it is called.
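
As a rough sketch of that pattern (this builds on your Model class above and is only meant to illustrate the idea, not the exact code from the autoencoder repo):

import torch
import torch.nn as nn
import torch.nn.functional as F

class StatefulModel(Model):
    """ Variant of the model above that keeps an explicit hidden state. """

    def init_hidden(self, batch_size):
        # One zero vector per GRU layer and per sentence in the batch.
        weight = next(self.parameters())
        self.hidden = weight.new_zeros(self.recurrent_layer.num_layers,
                                       batch_size,
                                       self.recurrent_layer.hidden_size)

    def forward(self, packed_sents):
        embedded_sents = nn.utils.rnn.PackedSequence(
            self.embedding(packed_sents.data), packed_sents.batch_sizes)
        out, self.hidden = self.recurrent_layer(embedded_sents, self.hidden)
        return F.log_softmax(self.fc1(out.data), dim=1)

# In the training loop, before each batch:
#     model.init_hidden(batch_size=len(sents))
#     out = model(x)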

I hope that helps.

Thank you for your response @vdw. I think that clarifies (1).

Regarding (2):
Yes, the hidden state should be reset after each minibatch. In the first code snippet of the model, nothing is done with the hidden state between two different forward passes.

out_packed_sequence, hidden = self.recurrent_layer(embedded_sents)

I think this means that the hidden state is reset after each forward pass, and therefore reset after each minibatch.
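
A small sketch I used to convince myself of this (assuming that PyTorch defaults the initial hidden state to zeros when none is given, which is what the docs say):

import torch
import torch.nn as nn

gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2)
seq = torch.randn(5, 3, 8)            # (seq_len, batch, input_size)

out_default, h_default = gru(seq)     # no initial hidden state given
h0 = torch.zeros(2, 3, 16)            # (num_layers, batch, hidden_size)
out_zeros, h_zeros = gru(seq, h0)     # explicit all-zeros initial state

print(torch.allclose(out_default, out_zeros))  # True
print(torch.allclose(h_default, h_zeros))      # True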

However, I still wonder what happens to the hidden state inside one minibatch of sentences. Suppose we have the following minibatch with two sentences: [[1, 2, 3], [4, 5, 6]].

My question is:
Will the hidden state of the third item (3) from sentence 1 be passed into the computation of the first element (4) of sentence two?
If the answer is yes, could this be a problem? I ask because the dataset is randomly shuffled before creating minibatches, so the sentences inside each minibatch should not depend on each other.

thanks!

No, the sentences in a batch are processed independently and in parallel in one go. Otherwise there would be no need for the batch_size dimension in the shape of the hidden state.
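
If you want to convince yourself, a toy sketch like this (my own example, not your code) shows that each sequence in the batch ends up with its own hidden vector, identical to what you get when processing it alone:

import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=6, num_layers=1)

a = torch.randn(3, 4)  # "sentence" 1: 3 time steps
b = torch.randn(3, 4)  # "sentence" 2: 3 time steps

# Process the two sequences as one batch of size 2 ...
batch = torch.stack([a, b], dim=1)   # (seq_len=3, batch=2, input_size=4)
_, h_batch = gru(batch)              # h_batch: (num_layers=1, batch=2, hidden=6)

# ... and each one alone. The per-sentence hidden states match,
# so nothing leaks from one sentence into the other.
_, h_a = gru(a.unsqueeze(1))         # (1, 1, 6)
_, h_b = gru(b.unsqueeze(1))

print(torch.allclose(h_batch[:, 0], h_a[:, 0], atol=1e-6))  # True
print(torch.allclose(h_batch[:, 1], h_b[:, 0], atol=1e-6))  # True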

I understand. Thank you for helping me. This answers my questions.