Batching with seq2seq model

I have a baseline seq2seq model with an LSTM encoder and decoder. The encoder processes the input sequence, and its final hidden and cell states are presumed to contain a summary of the entire sequence. This summary is then fed as the initial hidden and cell state to the LSTM decoder, which generates the output sequence token by token.

I have batched sequences of different lengths by padding at the end (using torch.nn.utils.rnn.pack_padded_sequence). Consider the case where a sequence has padding tokens towards the end. If we pass this sequence through the LSTM encoder, the final hidden/cell state retrieved will not be the real one, as this state is reached after processing the padding tokens. Ideally, we would want the hidden state right after the last non-padding token was processed. How can this be achieved?

If you pass a PackedSequence (e.g. the output of pack_padded_sequence) through the LSTM, the (h_n, c_n) that you get back correspond to the last valid, non-padded timestep of each sequence, so you don’t have to worry about it. For example, see the discussion at link.

However, if you want to do it yourself, you can unpack the output and then index the last valid vector of each sequence using the lengths returned by the unpacking. E.g., starting from a packed sequence:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Create a B x T x D = 3 x 3 x 3 padded batch (batch_first layout):
#   row 0: 2 valid timesteps, then 1 padding step
#   row 1: 1 valid timestep, then 2 padding steps
#   row 2: 3 valid timesteps, no padding
seq = torch.tensor([
    [[1, 1, 1], [2, 2, 2], [0, 0, 0]],
    [[3, 3, 3], [0, 0, 0], [0, 0, 0]],
    [[4, 4, 4], [5, 5, 5], [6, 6, 6]],
])
lens = [2,1,3]
packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)

# To get the last valid entry per row of such a packed sequence,
# unpack it and index each row at its own (length - 1).
seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)

valid_last_entry = seq_unpacked[torch.arange(len(lens_unpacked)), lens_unpacked - 1]
assert torch.all(valid_last_entry == torch.tensor([[2,2,2],[3,3,3],[6,6,6]]))
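
The claim that (h_n, c_n) already correspond to the last valid timestep can be checked by running an actual LSTM over a packed batch and comparing h_n with the manually extracted outputs. The following is only a minimal sketch: the sizes, the random input, and the names encoder, last_valid, and gap are made up for illustration, and it assumes a single-layer, unidirectional nn.LSTM with batch_first=True.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical toy sizes, chosen only for illustration.
batch, max_len, in_dim, hid_dim = 3, 4, 5, 6
lens = torch.tensor([2, 1, 4])  # valid length of each sequence (kept on CPU)

# Padded batch (batch_first layout) with zeros in the padding positions.
x = torch.randn(batch, max_len, in_dim)
for i, n in enumerate(lens.tolist()):
    x[i, n:] = 0.0

encoder = nn.LSTM(in_dim, hid_dim, batch_first=True)

with torch.no_grad():
    # Packed input: each sequence is only processed up to its true length.
    packed = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
    packed_out, (h_n, c_n) = encoder(packed)

    # Manual route: unpack the per-step outputs and pick step (length - 1).
    out, out_lens = pad_packed_sequence(packed_out, batch_first=True)
    last_valid = out[torch.arange(batch), out_lens - 1]

    # Same padded batch without packing: the padding steps are processed too.
    _, (h_padded, _) = encoder(x)

# h_n has shape (num_layers * num_directions, batch, hid_dim), here (1, 3, 6).
# Its last layer matches the manually extracted last valid outputs.
assert torch.allclose(h_n[-1], last_valid)

# Largest per-sequence difference between packed and unpacked final states.
gap = (h_n[-1] - h_padded[-1]).abs().max(dim=1).values
print(gap)
```

In the seq2seq setup described in the question, the (h_n, c_n) obtained from the packed input are the states that would be passed to the decoder as its initial hidden and cell state. The printed gap shows how far the unpacked run drifts from them for each sequence in the batch.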
