The PyTorch word language model example contains the following code:
# Starting each batch, we detach the hidden state from how it was previously produced.
# If we didn't, the model would try backpropagating all the way to start of the dataset.
hidden = repackage_hidden(hidden)
model.zero_grad()
output, hidden = model(data, hidden)
loss = ... etc
( https://github.com/pytorch/examples/blob/master/word_language_model/main.py )
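For reference, `repackage_hidden` in that example is essentially a recursive detach over the hidden-state tuple, so that backprop stops at the batch boundary. A minimal sketch (the helper name comes from the linked example; this body is my paraphrase of it):

```python
import torch

def repackage_hidden(h):
    """Detach hidden states from their history so backprop stops here.

    LSTM hidden state is a tuple (h, c), so recurse over tuples and
    detach each tensor leaf.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)
```

After this call, the tensors carry the same values but no longer reference the computation graph of the previous batch.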
It works fine if all sequences in a batch have the same length, but what if a batch consists of sequences of different lengths stacked together? For example, assume I am training an LSTM to learn character sequences in the order they appear in the alphabet.
Consider a batch (batch size 2 here) consisting of the following timesteps, fed into the LSTM during training:
[ [a, b, c, x, y, z], [e, f, g, h, q, r] ]
I want my model to learn that b follows a and c follows b, but not that x follows c. That is, the model should learn the sequences “a, b, c”, “x, y, z”, “e, f, g, h”, and “q, r”. So I need to detach the hidden state at timestep “x” in the first batch element from timestep “c”, and at timestep “q” in the second batch element from timestep “h”.
In other words, I want to reinitialize the LSTM's hidden state to zeros for batch element 1 at step 4 (“x”), detaching it from the previous hidden state, and do the same for batch element 2 at a different timestep (step 5, “q”).
Is it possible?
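One way I imagine this could be sketched (my own assumption, not code from the example) is to step the LSTM manually with `nn.LSTMCell` and zero the state per batch element using a per-timestep reset mask. Multiplying the previous state by zero both reinitializes it and blocks the gradient path for that element, so no per-element `detach()` is needed:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, input_size, hidden_size, steps = 2, 4, 8, 6
cell = nn.LSTMCell(input_size, hidden_size)

# Hypothetical inputs of shape (steps, batch, input_size); in the question
# these would be embeddings of the characters a..z.
inputs = torch.randn(steps, batch, input_size)

# reset[t, b] == 1 means: batch element b starts a new sequence at step t,
# so its state must be zeroed before processing that step. Matching the
# question: element 0 resets at step 3 ("x"), element 1 at step 4 ("q").
reset = torch.zeros(steps, batch)
reset[3, 0] = 1.0
reset[4, 1] = 1.0

h = torch.zeros(batch, hidden_size)
c = torch.zeros(batch, hidden_size)

outputs = []
for t in range(steps):
    keep = (1.0 - reset[t]).unsqueeze(1)  # (batch, 1), broadcasts over hidden dim
    # For reset elements, h * 0 and c * 0 contribute zero gradient to the
    # previous timesteps, so this both zeroes the state and cuts backprop
    # for exactly those elements; the other elements keep full history.
    h, c = h * keep, c * keep
    h, c = cell(inputs[t], (h, c))
    outputs.append(h)

out = torch.stack(outputs)  # (steps, batch, hidden_size)
```

This avoids mutating the graph per element; the masked multiply does the detaching implicitly because a zeroed state passes no gradient backwards.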