Why is the hidden variable kept outside the network class in the PyTorch examples language model?

Hi

I am curious about the implementation of the word-level language model here: https://github.com/pytorch/examples/tree/master/word_language_model, specifically the training loop in the main.py file.

> hidden = model.init_hidden(args.batch_size)
> for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
>     data, targets = get_batch(train_data, i)
>     # Starting each batch, we detach the hidden state from how it was previously produced.
>     # If we didn't, the model would try backpropagating all the way to start of the dataset.
>     hidden = repackage_hidden(hidden)
>     model.zero_grad()
>     output, hidden = model(data, hidden)

Since the network already handles its inputs in batches, why create an external hidden variable? I think the hidden state should be part of the network class itself.

So, why keep it outside (why not just keep a hidden variable inside the forward function of the RNNModel)?
What benefits do we get, and is there a cleaner way to implement this?

A training loop that looked like this would be nicer.

for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
    data, targets = get_batch(train_data, i)
    model.zero_grad()
    output = model(data)
    loss = criterion(output.view(-1, ntokens), targets)
    loss.backward()

But for that, model.forward would have to look something like this:

 def forward(self, input):
     ...
     if self.hidden is None or batch_size_doesnt_match_hidden_size:
         self.hidden = self.init_hidden()
     output, self.hidden = self.rnn(emb, self.hidden.detach())
     ...
     return just_the_output

How about adding a keep_hidden=True parameter to the LSTM class which would ask it to manage its hidden state internally?
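
To make that idea concrete, here is a minimal, self-contained sketch of a module that manages its own hidden state. This is not the code from the example repo: it assumes a GRU (so the hidden state is a single tensor rather than the LSTM's (h, c) tuple), and the class name SelfContainedRNN and its constructor arguments are made up for illustration.

import torch
import torch.nn as nn

class SelfContainedRNN(nn.Module):
    """Hypothetical variant of the example's RNNModel that keeps its hidden state internal."""

    def __init__(self, ntoken, ninp, nhid, nlayers):
        super().__init__()
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.GRU(ninp, nhid, nlayers)    # GRU: hidden state is a single tensor
        self.decoder = nn.Linear(nhid, ntoken)
        self.nhid = nhid
        self.nlayers = nlayers
        self.hidden = None                        # managed internally, not passed around

    def forward(self, input):                     # input: (seq_len, batch) of token ids
        emb = self.encoder(input)
        batch_size = input.size(1)
        if self.hidden is None or self.hidden.size(1) != batch_size:
            # (Re)initialize when there is no state yet or the batch size changed.
            weight = next(self.parameters())
            self.hidden = weight.new_zeros(self.nlayers, batch_size, self.nhid)
        # Detach so backprop stops at the start of the current batch,
        # which is what repackage_hidden() does in the example's training loop.
        output, self.hidden = self.rnn(emb, self.hidden.detach())
        return self.decoder(output)               # (seq_len, batch, ntoken)

One likely reason the example keeps the state outside is that this version makes the module stateful: evaluation, text generation, and wrappers like DataParallel would all have to be careful about the hidden state carried inside the module, whereas passing it explicitly keeps forward a pure function of its arguments.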


Thanks! I will try it this way… Could you explain a bit more? get_batch(train_data, i) is called with i in increments of bptt; why does bptt come into play here?

The training loop repackages the hidden state before each batch of training data, which means that BPTT (backpropagation through time) is cut off at the beginning of each batch.

So if you want training to be able to backpropagate through 50 timesteps, then each batch must contain at least 50 timesteps.
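
To make that concrete, the two helpers involved look roughly like this. This is a paraphrase of main.py rather than a verbatim copy; in the real script the window length comes from the parsed args.bptt flag, while here it is passed in explicitly so the sketch stands on its own.

import torch

def repackage_hidden(h):
    # Detach hidden states from their history so gradients stop flowing here.
    if isinstance(h, torch.Tensor):
        return h.detach()
    # LSTM hidden state is an (h, c) tuple, so recurse over it.
    return tuple(repackage_hidden(v) for v in h)

def get_batch(source, i, bptt):
    # Slice out up to bptt timesteps starting at row i of the batched data,
    # plus the same slice shifted by one timestep as the target.
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]                      # (seq_len, batch_size)
    target = source[i + 1:i + 1 + seq_len].view(-1)   # flattened for the loss
    return data, target

Stepping i by bptt just makes the windows consecutive and non-overlapping; the detach in repackage_hidden is what actually limits gradients to at most bptt timesteps.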
