Tensorflow-esque bucket by sequence length

Good question!

Essentially, it’s just for convenience; the model is agnostic to the sequence lengths, at least for training, since the lengths of the output sequences are known. But you’re right: predicting then only works one sequence at a time and no longer in batches, since the output sequences are likely to differ in length.

I actually implemented bucketing for Seq2Seq and use it all the time…again, simply for convenience and performance.

I don’t see how the output lengths would be known, even during training. You might* know the target lengths, but the actual output is purely up to the model (though you’d probably enforce some max length). Especially when you’re just starting training and the model is clueless, any length would be possible.

*might: sometimes there is more than one possible target. In translation, for example, there can be several “correct” translations, possibly of different lengths, so you wouldn’t even know what the target length is.

Well, the model doesn’t “know” the length of the output sequence. But training on a single sample is done when the decoder reaches the end-of-sequence (EOS) token or when there is no next token. So the model only knows the end when it has reached it…again, I’m referring only to training.

Using a max_length or similar is only relevant for predicting sequences, since the next predicted word might never be the EOS token that tells the decoder to stop. During training, decoding ends when the length of the target sequence is reached. This is the basic setup for Seq2Seq with a batch size of 1.
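To make the training/prediction difference concrete, here is a framework-free sketch. It assumes a hypothetical `step(prev_token)` callable standing in for one decoder step and an assumed EOS id of 1. With teacher forcing the loop runs once per target token, so its length is fixed by the target; greedy prediction has no target and must stop at EOS or at `max_length`.

```python
from typing import Callable, List, Optional

EOS = 1  # assumed end-of-sequence token id (hypothetical)

# A decoder step maps the previously fed token (None at the start) to a predicted token.
Step = Callable[[Optional[int]], int]


def decode_training(step: Step, target: List[int]) -> List[int]:
    """Teacher forcing: one decoder step per target token.

    The number of steps equals len(target), regardless of what the model predicts.
    """
    predictions = []
    prev = None
    for gold in target:
        predictions.append(step(prev))
        prev = gold  # feed the ground-truth token, not the prediction
    return predictions


def decode_inference(step: Step, max_length: int) -> List[int]:
    """Greedy decoding: stop at EOS, or after max_length steps if EOS never comes."""
    predictions = []
    prev = None
    for _ in range(max_length):
        token = step(prev)
        predictions.append(token)
        if token == EOS:
            break
        prev = token  # feed the model's own prediction back in
    return predictions


def clueless(prev: Optional[int]) -> int:
    """A stand-in for an untrained model that never predicts EOS."""
    return 7


print(len(decode_training(clueless, [5, 3, 9, EOS])))  # 4: fixed by the target
print(len(decode_inference(clueless, max_length=10)))  # 10: only max_length stops it
```

In a real Seq2Seq model `step` would also carry the decoder's hidden state; it is left out here to keep the control flow visible.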

The only difference with bucketing is that all target sequences in a batch have the same length, so the model reaches the EOS token for all samples in the batch at the same time.
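A small NumPy illustration of that point, with made-up token ids and assumed special ids (EOS = 1, padding = 0): targets of equal length stack directly into a batch array and hit EOS at the same decoding step, while mixed lengths must be padded and end at different steps.

```python
import numpy as np

EOS, PAD = 1, 0  # assumed special token ids (hypothetical)


def pad_batch(seqs, pad_value=PAD):
    """Right-pad each sequence to the longest one and stack into a 2-D array."""
    max_len = max(len(s) for s in seqs)
    return np.array([list(s) + [pad_value] * (max_len - len(s)) for s in seqs])


def eos_steps(batch):
    """Decoding step (column index) at which each row first reaches EOS."""
    return [int(np.argmax(row == EOS)) for row in batch]


same_length = [[4, 8, 2, EOS], [5, 3, 7, EOS], [9, 3, 6, EOS]]
mixed_length = [[4, 8, EOS], [5, 3, 7, 2, EOS], [9, EOS]]

bucketed = np.stack(same_length)  # works as-is: shape (3, 4), no padding
padded = pad_batch(mixed_length)  # shape (3, 5), needs PAD tokens

print(eos_steps(bucketed))         # [3, 3, 3]: all rows finish together
print(eos_steps(padded))           # [2, 4, 1]: rows finish at different steps
print(int((padded == PAD).sum()))  # 5 padding positions
```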

Sorry for butting in too, and apologies that this isn’t related to PyTorch. I have been scratching my head trying to implement exactly this in TensorFlow/Keras. Would anyone here know how to approach this in TensorFlow/Keras? My data is set up like this:

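import tensorflow as tf
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)

# Each element is one (sequence, target) pair: inputs are (time, 50), targets are (time, 2)
signature = (tf.TensorSpec(shape=(None, 50), dtype=tf.float64), tf.TensorSpec(shape=(None, 2), dtype=tf.float64))

# zip() yields one pair per sample; returning the tuple (X_train, y_train) would yield the whole arrays instead
train_data = tf.data.Dataset.from_generator(lambda: zip(X_train, y_train), output_signature=signature)
val_data = tf.data.Dataset.from_generator(lambda: zip(X_test, y_test), output_signature=signature)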
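One framework-agnostic way to approach this, sketched in plain Python/NumPy under the assumption that `X_train` and `y_train` are lists of 2-D arrays with a variable first (time) dimension: assign each sample to a bucket by its input length, treating boundaries the way TensorFlow’s `bucket_by_sequence_length` does (bucket `i` covers lengths in `[boundaries[i-1], boundaries[i])`), and pad each batch only to the longest sequence in that batch. The function names and boundary values are illustrative, not part of any library, and unlike TensorFlow’s version this uses one batch size for all buckets.

```python
import bisect
from collections import defaultdict

import numpy as np


def pad_to_longest(arrays, pad_value=0.0):
    """Stack arrays of shape (length_i, ...) into (batch, max_length, ...), right-padded."""
    arrays = [np.asarray(a, dtype=np.float64) for a in arrays]
    max_len = max(a.shape[0] for a in arrays)
    out = np.full((len(arrays), max_len) + arrays[0].shape[1:], pad_value)
    for i, a in enumerate(arrays):
        out[i, : a.shape[0]] = a
    return out


def bucket_batches(inputs, targets, boundaries, batch_size):
    """Yield (input_batch, target_batch) pairs grouped by input length.

    Bucket i holds inputs whose length lies in [boundaries[i-1], boundaries[i]);
    lengths >= boundaries[-1] go to the last bucket. Batches come out bucket by
    bucket (no shuffling), each padded only to its own longest sequence.
    """
    buckets = defaultdict(list)
    for x, y in zip(inputs, targets):
        buckets[bisect.bisect_right(boundaries, len(x))].append((x, y))

    for bucket_id in sorted(buckets):
        items = buckets[bucket_id]
        for start in range(0, len(items), batch_size):
            chunk = items[start : start + batch_size]
            yield pad_to_longest([x for x, _ in chunk]), pad_to_longest([y for _, y in chunk])
```

Such a generator could then be wrapped with `tf.data.Dataset.from_generator(lambda: bucket_batches(X_train, y_train, [10, 20, 40], 32), output_signature=(tf.TensorSpec(shape=(None, None, 50), dtype=tf.float64), tf.TensorSpec(shape=(None, None, 2), dtype=tf.float64)))`, where the boundaries and batch size are placeholders. Recent TensorFlow versions also provide this transformation directly as `tf.data.Dataset.bucket_by_sequence_length` (older ones as `tf.data.experimental.bucket_by_sequence_length`), which is worth checking first for your version.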

This solution works when using a (map-style) Dataset. Your code requires the size of the dataset, but I have an IterableDataset, whose size is unknown. So how can I implement bucket_by_sequence_length?
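For a stream whose size is unknown, one option is to drop the precomputed index lists and buffer samples on the fly: keep one buffer per (input length, target length) pair, emit a batch whenever a buffer reaches `batch_size`, and flush the leftovers when the stream ends. A plain-Python sketch, assuming each sample is an `(input, target)` pair of sequences; the function name is hypothetical.

```python
from collections import defaultdict


def stream_equal_length_batches(samples, batch_size, drop_last=False):
    """Group a stream of (input, target) pairs into batches of equal length pairs.

    Works on any iterable, so the total number of samples never has to be known.
    """
    buffers = defaultdict(list)
    for inp, tgt in samples:
        key = (len(inp), len(tgt))
        buffers[key].append((inp, tgt))
        if len(buffers[key]) == batch_size:
            # This bucket is full: hand it out and start a fresh one for this key
            yield buffers.pop(key)
    if not drop_last:
        # Stream exhausted: emit whatever is left in the partially filled buckets
        for key in list(buffers):
            yield buffers.pop(key)
```

In PyTorch this could run inside an `IterableDataset`’s `__iter__`, with the `DataLoader` created with `batch_size=None` so it does not batch the already-batched output again; each yielded list would still need to be stacked into tensors. Buffers for rare length pairs can hold samples for a long time, so very long streams may call for a cap on the number of buffered samples.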

If anyone is interested, I’ve implemented a simplified version. It works the same way, but also supports Seq2Seq datasets, i.e., where both the inputs and the targets are sequences. Each batch only contains samples with the same combination of input and target length, which makes padding unnecessary.

import numpy as np
from torch.utils.data import Dataset, Sampler


class BaseDataset(Dataset):

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        if self.targets is None:
            return np.asarray(self.inputs[index])
        else:
            return np.asarray(self.inputs[index]), np.asarray(self.targets[index])
                
        
class EqualLengthsBatchSampler(Sampler):

    def __init__(self, batch_size, inputs, targets):
        
        # Throw an error if the number of inputs and targets don't match
        if targets is not None and len(inputs) != len(targets):
            raise ValueError("[EqualLengthsBatchSampler] inputs and targets have different sizes")
        
        # Remember batch size and number of samples
        self.batch_size, self.num_samples = batch_size, len(inputs)
        
        self.unique_length_pairs = set()
        self.lengths_to_samples = {}
        
        for i in range(0, len(inputs)):
            len_input = len(inputs[i])
            try:
                # Fails if targets is None or targets[i] is not a sequence but a scalar (e.g., a class label)
                len_target = len(targets[i])
            except TypeError:
                # In case of failure, we just set the length to 1 (the value doesn't matter, it only needs to be a constant)
                len_target = 1

            # Add length pair to set of all seen pairs
            self.unique_length_pairs.add((len_input, len_target))
        
            # For each length pair, keep track of the indices of the samples with this pair
            # E.g.: self.lengths_to_samples = { (4,5): [3,5,11], (5,5): [1,2,9], ...}
            if (len_input, len_target) in self.lengths_to_samples:
                self.lengths_to_samples[(len_input, len_target)].append(i)
            else:
                self.lengths_to_samples[(len_input, len_target)] = [i]
        
        # Convert set of unique length pairs to a list so we can shuffle it later
        self.unique_length_pairs = list(self.unique_length_pairs)
        
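    def __len__(self):
        # A batch sampler's length is the number of batches (DataLoader uses it as the number of iterations)
        return sum(int(np.ceil(len(indices) / self.batch_size)) for indices in self.lengths_to_samples.values())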
    
    def __iter__(self):

        # Shuffle list of unique length pairs
        np.random.shuffle(self.unique_length_pairs)
        
        # Iterate over all possible sentence length pairs
        for length_pair in self.unique_length_pairs:
            
            # Get indices of all samples for the current length pair,
            # for example, all indices with a length pair of (8,7)
            sequence_indices = self.lengths_to_samples[length_pair]
            sequence_indices = np.array(sequence_indices)
            
            # Shuffle array of sequence indices
            np.random.shuffle(sequence_indices)

            # Compute the number of batches (array_split expects an integer)
            num_batches = int(np.ceil(len(sequence_indices) / self.batch_size))

            # Split into num_batches nearly equal chunks, none larger than batch_size
            for batch_indices in np.array_split(sequence_indices, num_batches):
                yield batch_indices.tolist()

I just implemented it for a tutorial, but maybe it’s useful for others.
