Batching with seq2seq model

If you are using a PackedSequence, e.g. one created with pack_padded_sequence, and pass it through the LSTM, the (h_n, c_n) you get back correspond to the last valid (non-padded) timestep of each sequence, so you don't have to handle the padding yourself. For example, see the discussion at link.
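One way to check this is to compare h_n with the unpacked output at each sequence's last valid timestep. The sketch below assumes a single-layer, unidirectional LSTM with batch_first=True; the sizes B, T, D, H and the random input are arbitrary. For a bidirectional LSTM, the backward direction's final state corresponds to the first timestep instead, so this comparison would only hold for the forward half.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Assumed toy setup: single-layer, unidirectional LSTM, batch-first tensors
B, T, D, H = 3, 4, 5, 6
lstm = nn.LSTM(input_size=D, hidden_size=H, batch_first=True)

x = torch.randn(B, T, D)           # padded batch; steps past each length are padding
lengths = torch.tensor([2, 4, 1])  # valid timesteps per sequence (not sorted)

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)

# Unpack per-timestep outputs to shape (B, T_max, H); batch order is restored
out, out_lens = pad_packed_sequence(packed_out, batch_first=True)

# Output at the last valid timestep of each sequence
last_out = out[torch.arange(B), out_lens - 1]

# h_n has shape (num_layers * num_directions, B, H); h_n[-1] is the top layer
print(torch.allclose(h_n[-1], last_out))  # True
```

Because nn.LSTM and pad_packed_sequence both undo the internal sorting when enforce_sorted=False, h_n and the unpacked output are already in the original batch order.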

However, if you want to do it yourself, you can unpack the output and then index the last valid vector of each sequence using the lengths returned by pad_packed_sequence. For example, starting from a packed sequence:

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# create a BxTxD = 3x3x3 tensor where
# the 1st entry in the batch has 2 valid timesteps and its 3rd timestep is padding,
# the 2nd entry has 1 valid timestep and its other two timesteps are padding,
# and the 3rd entry has all 3 timesteps valid
seq = torch.tensor([[[1,1,1],[2,2,2],[0,0,0]],[[3,3,3],[0,0,0],[0,0,0]],[[4,4,4],[5,5,5],[6,6,6]]])
lens = [2,1,3]
packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)

# to get the last valid entry of each row in a packed sequence like the one above,
# unpack it and index the corresponding timesteps
seq_unpacked, lens_unpacked = pad_packed_sequence(packed, batch_first=True)

# pair each batch index with the index of its last valid timestep (length - 1)
valid_last_entry = seq_unpacked[torch.arange(len(lens_unpacked)), lens_unpacked - 1]
assert torch.all(valid_last_entry == torch.tensor([[2,2,2],[3,3,3],[6,6,6]]))
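If you already have a padded, batch-first tensor and its lengths, torch.gather can do the same selection. This is a minimal sketch of that alternative; last_valid is a hypothetical helper name, and it assumes lengths is an int64 tensor on the same device as the padded data.

```python
import torch

def last_valid(padded: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return padded[b, lengths[b] - 1, :] for each b in a (B, T, D) tensor."""
    B, _, D = padded.shape
    # Index of shape (B, 1, D) pointing at each row's last valid timestep
    idx = (lengths - 1).view(B, 1, 1).expand(B, 1, D)
    # Gather along the time dimension, then drop the singleton time axis
    return padded.gather(1, idx).squeeze(1)

seq = torch.tensor([[[1, 1, 1], [2, 2, 2], [0, 0, 0]],
                    [[3, 3, 3], [0, 0, 0], [0, 0, 0]],
                    [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])
lens = torch.tensor([2, 1, 3])
print(last_valid(seq, lens).tolist())  # [[2, 2, 2], [3, 3, 3], [6, 6, 6]]
```

Note that if you run the LSTM on the padded tensor without packing, h_n comes from the final padded step, so selecting the output at the last valid timestep (as above) is what recovers the per-sequence result.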