Using DistributedSampler in combination with batch_sampler to make sure batches have sentences of similar length

With torchtext 0.9.0, BucketIterator was deprecated and DataLoader is the recommended replacement, which is great since DataLoader is compatible with DistributedSampler and hence DDP. However, DataLoader has no out-of-the-box way to build batches from sequences of similar length. The migration tutorial recommends using the batch_sampler argument of DataLoader to pool together samples of similar length. Unfortunately, batch_sampler is mutually exclusive with the sampler argument, which is where a DistributedSampler would normally go.
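
The conflict is easy to reproduce. Below is a minimal sketch, assuming a toy list as the dataset and a single process in which num_replicas and rank are passed explicitly so that no process group has to be initialized:

```python
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, SequentialSampler

# Toy dataset (hypothetical): eight integers standing in for sentences.
dataset = list(range(8))

# Passing num_replicas and rank explicitly avoids the need for an initialized process group.
dist_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)

try:
    # DataLoader refuses to take sampler and batch_sampler at the same time.
    DataLoader(dataset, sampler=dist_sampler, batch_sampler=batch_sampler)
    error = None
except ValueError as exc:
    error = exc

print(type(error).__name__, error)
```

So the per-rank split and the length-based batching have to be combined inside a single object that is handed over as batch_sampler.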

Does anyone have any suggestions/ideas on how to make sure that, when using DDP, each batch contains sequences of as similar length as possible?

I’ve solved the problem by hacking together DistributedSampler and the batch_sampler function from the migration tutorial.
Note that this implementation is for a seq2seq problem where each dataset item is a tuple of 4 values: source seq, length of source seq, target seq and length of the target seq. This layout only matters in the first lines of __init__, where the source lengths are read; everything after that is independent of it.
I’ve abstracted the length sorting into a BatchSamplerSimilarLength Sampler class, which can also be used on its own without DDP.

import random

import torch
from torch.utils.data import Sampler


class BatchSamplerSimilarLength(Sampler):
    def __init__(self, dataset, batch_size, indices=None, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # get (index, source length) pairs; note that this iterates over the whole dataset once
        self.indices = [(i, src_len) for i, (src, src_len, trg, trg_len) in enumerate(dataset)]
        # if indices are passed, then use only the ones passed (for DDP)
        if indices is not None:
            self.indices = torch.tensor(self.indices)[indices].tolist()

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.indices)

        pooled_indices = []
        # create pool of indices with similar lengths
        for i in range(0, len(self.indices), self.batch_size * 100):
            pooled_indices.extend(sorted(self.indices[i:i + self.batch_size * 100], key=lambda x: x[1]))
        self.pooled_indices = [x[0] for x in pooled_indices]

        # yield indices for current batch
        batches = [self.pooled_indices[i:i + self.batch_size] for i in
                   range(0, len(self.pooled_indices), self.batch_size)]

        if self.shuffle:
            random.shuffle(batches)
        for batch in batches:
            yield batch

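    def __len__(self):
        # number of batches, counting a final partial batch; computed from
        # self.indices so that len() also works before the first __iter__ call
        return (len(self.indices) + self.batch_size - 1) // self.batch_size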
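
The pooling step is independent of PyTorch, so it can be tried out in isolation. The following is only a sketch of the same idea in plain Python: similar_length_batches and pool_factor are hypothetical names (pool_factor stands for the hard-coded 100 above), and the lengths are made up.

```python
import random


def similar_length_batches(lengths, batch_size, pool_factor=100, shuffle=True, rng=None):
    """Group sample indices into batches whose samples have similar lengths.

    lengths[i] is the length of sample i.
    """
    if rng is None:
        rng = random.Random()
    order = list(range(len(lengths)))
    if shuffle:
        rng.shuffle(order)

    pool_size = batch_size * pool_factor
    pooled = []
    for start in range(0, len(order), pool_size):
        # Sort only inside each pool, so the shuffle still decides which samples share a pool.
        pooled.extend(sorted(order[start:start + pool_size], key=lambda i: lengths[i]))

    # Cut the pooled order into consecutive batches; the last one may be smaller.
    batches = [pooled[i:i + batch_size] for i in range(0, len(pooled), batch_size)]
    if shuffle:
        # Shuffle whole batches so that training does not see lengths in increasing order.
        rng.shuffle(batches)
    return batches


lengths = [5, 40, 7, 38, 6, 41, 39, 8]
batches = similar_length_batches(lengths, batch_size=4, pool_factor=2, rng=random.Random(0))
for batch in batches:
    print([lengths[i] for i in batch])
```

A larger pool gives more uniform lengths inside a batch but less randomness between epochs; a pool of exactly one batch would not group by length at all.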

Then I created a subclass of DistributedSampler that overrides __iter__: it passes the indices computed by DistributedSampler to BatchSamplerSimilarLength, so the batch sampler only sees the indices assigned to the current rank and can then proceed as usual (batching by length).

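from typing import Optional

from torch.utils.data import Dataset, DistributedSampler


class DistributedBatchSamplerSimilarLength(DistributedSampler):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False, batch_size: int = 10) -> None:
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, seed=seed,
                         drop_last=drop_last)
        self.batch_size = batch_size

    def __iter__(self):
        # indices assigned to this rank by DistributedSampler
        indices = list(super().__iter__())
        # batch only this rank's indices by length; shuffle is forwarded so that
        # shuffle=False also disables the shuffling inside the batch sampler
        batch_sampler = BatchSamplerSimilarLength(self.dataset, batch_size=self.batch_size,
                                                  indices=indices, shuffle=self.shuffle)
        return iter(batch_sampler)

    def __len__(self) -> int:
        # number of batches on this rank, counting a final partial batch
        return (self.num_samples + self.batch_size - 1) // self.batch_size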

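An instance of DistributedBatchSamplerSimilarLength has to be passed as the batch_sampler argument of the DataLoader. As with a plain DistributedSampler, call set_epoch(epoch) at the start of every epoch so that the assignment of samples to ranks is reshuffled.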
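
To illustrate the wiring, here is a self-contained sketch that simulates two ranks inside one process. ToySeq2SeqDataset, LengthBatchedDistributedSampler and collate are hypothetical names. The sampler is a condensed stand-in for the classes above, not the same implementation: it sorts the whole per-rank shard instead of pooling.

```python
import random

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, DistributedSampler


class ToySeq2SeqDataset(Dataset):
    """Hypothetical dataset yielding (src, src_len, trg, trg_len) tuples."""

    def __init__(self, n=64, seed=0):
        rng = random.Random(seed)
        self.items = []
        for _ in range(n):
            src_len, trg_len = rng.randint(3, 30), rng.randint(3, 30)
            self.items.append((torch.ones(src_len, dtype=torch.long), src_len,
                               torch.ones(trg_len, dtype=torch.long), trg_len))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


class LengthBatchedDistributedSampler(DistributedSampler):
    """Condensed stand-in: sorts this rank's indices by source length, then chunks them."""

    def __init__(self, dataset, batch_size, **kwargs):
        super().__init__(dataset, **kwargs)
        self.batch_size = batch_size

    def __iter__(self):
        indices = sorted(super().__iter__(), key=lambda i: self.dataset[i][1])
        batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
        if self.shuffle:
            # Seeded with the epoch so the batch order is reproducible per epoch.
            random.Random(self.seed + self.epoch).shuffle(batches)
        return iter(batches)

    def __len__(self):
        return (self.num_samples + self.batch_size - 1) // self.batch_size


def collate(batch):
    """Pad source and target sequences to the longest one in the batch."""
    src, src_len, trg, trg_len = zip(*batch)
    return (pad_sequence(list(src), batch_first=True), torch.tensor(src_len),
            pad_sequence(list(trg), batch_first=True), torch.tensor(trg_len))


dataset = ToySeq2SeqDataset()
rank_batches = {}
for rank in range(2):  # simulate two DDP ranks inside one process
    sampler = LengthBatchedDistributedSampler(dataset, batch_size=8, num_replicas=2, rank=rank)
    loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate)
    sampler.set_epoch(0)  # in training, call this with the current epoch before each epoch
    rank_batches[rank] = list(sampler)
    for src, src_len, trg, trg_len in loader:
        # Each batch is padded only up to the longest source sequence it contains.
        assert src.shape[1] == int(src_len.max())
    print(rank, len(loader), "batches")
```

In a real DDP job the loop over ranks disappears: each process builds one sampler, and num_replicas and rank can be left as None so that they are read from the initialized process group.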