I have a very peculiar use case where some sequences in a batch start “later” than others. For example, consider the following batch:
| seq / t | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| 0 | a | b | c | d | e |
| 1 | == | g | h | k | m |
| 2 | == | == | == | n | o |
Here, sequence 0 starts at time 0, sequence 1 at time 1, and sequence 2 at time 3. The corresponding PackedSequence would be something like PackedSequence(data=[a, b, g, c, h, d, k, n, e, m, o], batch_sizes=[1, 2, 2, 3, 3]), i.e. stored time-major, with the batch size growing over time. Is there a good way to represent this “left-padded” batch as a PackedSequence and pass it to an RNN such as an LSTM or GRU? Or is the only way to left-pad the batch manually and mask the outputs? That is not ideal, because the sequences that start later would not get zeros as their initial hidden states.
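For comparison, here is what PyTorch's own packing produces for the three sequences in the table. The packing utilities (`pack_sequence`, `pack_padded_sequence`) align every sequence at its own first step, so `data` is time-major and `batch_sizes` can only shrink over time. The integer IDs 1 to 11 are hypothetical stand-ins for the letters.

```python
import torch
from torch.nn.utils.rnn import pack_sequence

# Hypothetical integer IDs standing in for the letters in the table:
# a..e -> 1..5, g h k m -> 6..9, n o -> 10, 11
seq0 = torch.tensor([1, 2, 3, 4, 5])  # a b c d e (length 5)
seq1 = torch.tensor([6, 7, 8, 9])     # g h k m   (length 4)
seq2 = torch.tensor([10, 11])         # n o       (length 2)

# Already sorted longest first, as pack_sequence expects by default.
packed = pack_sequence([seq0, seq1, seq2])

# Every sequence starts at its own t = 0, so batch sizes never increase.
print(packed.batch_sizes.tolist())  # [3, 3, 2, 2, 1]
# Time-major: all step-0 items first, then all step-1 items, and so on.
print(packed.data.tolist())         # [1, 6, 10, 2, 7, 11, 3, 8, 4, 9, 5]
```

Since an RNN applies the same update at every step, a sequence that starts at time 3 from a zero hidden state sees exactly what it would see starting at its own t = 0; packing only changes how the outputs line up in time.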
I hope I was clear enough. If not, please let me know so I can clarify. Thanks for your help in advance!
Do you have all of the sequences in hand at one time?
If so, I’d think you could do something like this in your collate function for your dataloader:
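from torch.nn.utils.rnn import pad_sequence
sequences.sort(key=len, reverse=True)  # sequences: list of 1-D tensors, longest first
lengths = [len(sequence) for sequence in sequences]
# Reverse each sequence in time (not the order of the list), right-pad, then flip the time axis back
reversed_sequences = [sequence.flip(0) for sequence in sequences]
padded = pad_sequence(reversed_sequences, batch_first=True, padding_value=vocab["<pad>"]).flip(1)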
Basically: reverse each sequence, right-pad, then reverse back along the time axis.
Then you’d have to mask your network outputs appropriately when measuring loss.
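A self-contained sketch of the reverse, pad, reverse idea as a collate function that also returns a mask, so padded positions can be left out of the loss. `left_pad_collate`, `masked_mean` and `PAD_ID = 0` are hypothetical names chosen for illustration; in practice the per-step loss could come from something like `F.cross_entropy(..., reduction="none")`.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 0  # hypothetical padding index; substitute vocab["<pad>"]


def left_pad_collate(sequences):
    """Left-pad a list of 1-D LongTensors to (batch, max_len) and build a mask.

    The mask is True on real tokens and False on padding.
    """
    lengths = torch.tensor([len(seq) for seq in sequences])
    # Reverse each sequence in time, right-pad, then flip the time axis back.
    flipped = [seq.flip(0) for seq in sequences]
    padded = pad_sequence(flipped, batch_first=True, padding_value=PAD_ID).flip(1)
    max_len = padded.size(1)
    # Padding sits on the left, so position t is real iff t >= max_len - length.
    positions = torch.arange(max_len).unsqueeze(0)        # (1, max_len)
    mask = positions >= (max_len - lengths).unsqueeze(1)  # (batch, max_len)
    return padded, mask


def masked_mean(per_step_loss, mask):
    """Average a (batch, max_len) loss over unpadded positions only."""
    mask = mask.to(per_step_loss.dtype)
    return (per_step_loss * mask).sum() / mask.sum()


batch, mask = left_pad_collate([torch.tensor([1, 2, 3, 4, 5]),
                                torch.tensor([6, 7, 8, 9]),
                                torch.tensor([10, 11])])
print(batch.tolist())  # [[1, 2, 3, 4, 5], [0, 6, 7, 8, 9], [0, 0, 0, 10, 11]]
print(mask.sum(dim=1).tolist())  # [5, 4, 2]
print(masked_mean(torch.ones(3, 5), mask).item())  # 1.0
```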
I didn’t test the above, but I don’t see why it would not work assuming the rest of your model/training/loss was set up right.
Thanks for the reply, Wesley! Your code basically left-pads the batch, right? That would work, except that the shorter sequences that start later would not receive zeros as their initial hidden states (because the RNN would already have processed the padding), which a PackedSequence would avoid. I guess this doesn't really matter in the end?
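A quick way to see the concern: starting from the usual all-zero hidden state, a single padding step already moves a GRU's hidden state away from zero, because of the bias terms. This sketch assumes the pad step's input vector is all zeros (e.g. a zeroed pad embedding); the layer sizes and seed are arbitrary.

```python
import torch

torch.manual_seed(0)
gru = torch.nn.GRU(input_size=4, hidden_size=3, batch_first=True)

# One padding step with an all-zero input vector (assumed zeroed pad embedding).
pad_step = torch.zeros(1, 1, 4)  # (batch=1, seq_len=1, input_size=4)
h0 = torch.zeros(1, 1, 3)        # (num_layers=1, batch=1, hidden_size=3)

_, h1 = gru(pad_step, h0)
# The biases make h1 non-zero, so the first real token after left padding
# starts from h1 instead of from zeros.
print(h1)
```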
Yeah, that should left-pad the batch. The shorter sequences should receive zeros.
As far as what would happen with the hidden state, I’m having a hard time visualizing it without actually playing with it. If I were you, I’d probably just try to run a single batch through a test RNN layer and see what comes out the other side.
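Along those lines, a minimal experiment (random inputs, arbitrary sizes and seed) that packs three sequences shaped like the ones in the question, runs them through an `nn.LSTM`, and then shifts each row of the unpacked output so it lines up with absolute time, i.e. padding on the left. The shifting loop is a hand-written step, not a built-in PyTorch utility, and it assumes all sequences end at the same time step, as in the question's table. Because packing starts every sequence from a zero state at its own first step, the shifted outputs should match running each sequence through the LSTM on its own.

```python
import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

torch.manual_seed(0)
input_size, hidden_size = 4, 3
lstm = torch.nn.LSTM(input_size, hidden_size, batch_first=True)

# Lengths 5, 4, 2, all ending at the last step like rows 0, 1, 2 in the table.
seqs = [torch.randn(n, input_size) for n in (5, 4, 2)]

# Packed run: each sequence starts from zero (h, c) at its own first step.
packed_out, _ = lstm(pack_sequence(seqs))
out, lengths = pad_packed_sequence(packed_out, batch_first=True)  # right-padded

# Shift each row so its last output sits at the final time step.
T = out.size(1)
aligned = torch.zeros_like(out)
for i, n in enumerate(lengths.tolist()):
    aligned[i, T - n:] = out[i, :n]

# Reference: run each sequence alone, again from a zero initial state.
matches = []
for i, seq in enumerate(seqs):
    ref, _ = lstm(seq.unsqueeze(0))  # (1, len, hidden_size)
    n = seq.size(0)
    matches.append(torch.allclose(aligned[i, T - n:], ref[0], atol=1e-5))
print(matches)  # [True, True, True]
```

The zeros in `aligned` mark the left padding, so the same kind of mask as in the collate step can be used for the loss.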