Possible to detach the LSTM hidden state only for a certain batch element at a certain timestep?

There is the following code in the PyTorch LSTM word language model example:

# Starting each batch, we detach the hidden state from how it was previously produced.
# If we didn't, the model would try backpropagating all the way to start of the dataset.
hidden = repackage_hidden(hidden)
output, hidden = model(data, hidden)
loss = ... etc

( https://github.com/pytorch/examples/blob/master/word_language_model/main.py )
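
For context, `repackage_hidden` is a small helper defined in that example which detaches every tensor in the hidden state, cutting the graph for the whole batch at once. A minimal sketch along those lines for current PyTorch, assuming the hidden state is a tensor or a (possibly nested) tuple of tensors such as the `(h_n, c_n)` pair returned by `nn.LSTM`:

```python
import torch


def repackage_hidden(h):
    """Detach hidden state(s) from the graph that produced them.

    h is a tensor or a (possibly nested) tuple of tensors, e.g. the (h_n, c_n)
    pair returned by nn.LSTM. Values are unchanged; only the history is cut.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)
```

Because `detach()` acts on the whole tensor, this cuts the history of every batch row together; it cannot cut just one row at one timestep.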

It works well if all sequences in a batch have the same length, but what if a batch consists of sequences of different lengths stacked together? For example (let’s assume that I am training an LSTM to learn character sequences in alphabetical order):

Consider a batch (batch size 2 here) consisting of the following timesteps, which is fed into the LSTM during training:

[ [a, b, c, x, y, z], [e, f, g, h, q, r] ]

I want the model to learn that b follows a and c follows b, but not that x follows c: the sequences it should learn are “a, b, c”, “x, y, z”, “e, f, g, h” and “q, r”. So I need to detach the hidden state at timestep “x” in the first batch element from timestep “c”, and the hidden state at timestep “q” in the second batch element from timestep “h”.

In other words, I want to reinitialize the LSTM hidden state to zeros for batch element 1 at step 4 (“x”), detaching it from the previous hidden state, and do the same for batch element 2 at a different timestep (step 5, “q”).

Is it possible?
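
One way to get exactly this behaviour without splitting the batch is to step through time manually with `nn.LSTMCell` and multiply the state by a per-element mask just before a step where a new sequence begins. Multiplying by 0 both resets that row to zeros and makes its gradient to earlier timesteps exactly zero, while the other rows keep their history. A minimal sketch: `run_with_resets` and the `reset` mask are hypothetical names, and a Python loop over `LSTMCell` is typically slower than a single `nn.LSTM` call.

```python
import torch
import torch.nn as nn


def run_with_resets(cell, inputs, reset):
    """Step an LSTMCell over time, zeroing (h, c) per batch element where reset is True.

    Hypothetical helper. inputs: (T, B, input_size); reset: (T, B) bool,
    True means "a new sequence starts at this step for this element".
    """
    T, B, _ = inputs.shape
    h = inputs.new_zeros(B, cell.hidden_size)
    c = inputs.new_zeros(B, cell.hidden_size)
    outputs = []
    for t in range(T):
        # keep is 0 for rows that start a new sequence at step t, 1 otherwise.
        keep = (~reset[t]).to(inputs.dtype).unsqueeze(1)
        # Multiplying by 0 zeros the state and blocks the gradient to earlier
        # steps for that row only; other rows are untouched.
        h, c = cell(inputs[t], (h * keep, c * keep))
        outputs.append(h)
    return torch.stack(outputs)  # (T, B, hidden_size)
```

For the batch in the question, the mask would be True at `reset[3, 0]` (“x”, step 4) and `reset[4, 1]` (“q”, step 5), using 0-based indices.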

Why can’t you just split it into 4 sequences?

They would be of different lengths if I split them. I know I can use padding, but that would hurt the model’s quality to some extent…

IIRC, PyTorch RNN modules won’t compute past each sequence’s length if you pass the input in as a packed sequence: http://pytorch.org/docs/0.2.0/nn.html#torch.nn.utils.rnn.pack_padded_sequence
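
A minimal sketch of the packed-sequence route for the four sub-sequences from the question, assuming a recent PyTorch (the `enforce_sorted=False` argument is not in the 0.2.0 API linked above) and a hypothetical character encoding a=0 to z=25 fed through a small embedding:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
seqs = ["abc", "xyz", "efgh", "qr"]
# Hypothetical encoding: "a" -> 0, ..., "z" -> 25.
idx = [torch.tensor([ord(ch) - ord("a") for ch in s]) for s in seqs]
lengths = torch.tensor([len(s) for s in seqs])

embed = nn.Embedding(26, 8)
lstm = nn.LSTM(input_size=8, hidden_size=16)  # time-major (T, B, F) by default

# Pad to (max_len, batch) = (4, 4). The pad value 0 collides with "a", but
# packing means padded steps are never fed to the LSTM.
padded = pad_sequence(idx)
packed = pack_padded_sequence(embed(padded), lengths, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
# out: (4, 4, 16), zeros at padded positions; out_lengths in original order.
out, out_lengths = pad_packed_sequence(packed_out)
```

Each sequence starts from a zero state, so “c” never feeds into “x”, and `h_n[0, i]` is the state at sequence `i`’s own last real step. The padded positions in `out` are zeros and should be kept out of the loss.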

Did you look into that as well? It shouldn’t hurt performance, right?
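
On the worry that padding hurts quality: as long as padded positions are excluded from the loss, they contribute nothing to it or to its gradient. A minimal sketch for next-character targets, assuming `F.cross_entropy`’s `ignore_index` (default -100) and using random logits as a hypothetical stand-in for the model’s per-step output:

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
seqs = ["abc", "xyz", "efgh", "qr"]
# Next-character targets (hypothetical encoding a=0..z=25): "abc" -> "bc", etc.
targets = [torch.tensor([ord(ch) - ord("a") for ch in s[1:]]) for s in seqs]
# Pad with -100 so padded steps are ignored: shape (max_len - 1, batch) = (3, 4).
padded_targets = pad_sequence(targets, padding_value=-100)

# Stand-in for the model's logits over 26 characters at each (step, batch) slot.
logits = torch.randn(padded_targets.shape[0], len(seqs), 26)

# Targets equal to -100 are skipped; the mean is over real targets only.
loss = F.cross_entropy(logits.reshape(-1, 26), padded_targets.reshape(-1), ignore_index=-100)
```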

Great! That’s what I was looking for, thank you very much.
