How to implement a custom distributed sampler


I’m working on sequence data and would like to group sequences of similar lengths into batches. Here is an example implementation (source):

To group texts of similar length together, as the legacy BucketIterator class did,
we first randomly create multiple "pools", each of size batch_size * 100.
Then we sort the samples within each pool by length.
This idea can be implemented succinctly through the batch_sampler argument of the PyTorch DataLoader.
batch_sampler accepts a Sampler or any Iterable that yields the indices of the next batch.
In the code below, we implement a generator that yields batches of indices whose corresponding samples are of similar length.
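import random

from torch.utils.data import DataLoader
from torchtext.datasets import IMDB

# Note: tokenizer is assumed to be defined earlier in the source this was copied from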

train_iter = IMDB(split='train')
train_list = list(train_iter)
batch_size = 8  # A batch size of 8

def batch_sampler():
    indices = [(i, len(tokenizer(s[1]))) for i, s in enumerate(train_list)]
    # shuffle first so that the pools are random, as described above
    random.shuffle(indices)
    pooled_indices = []
    # create pool of indices with similar lengths 
    for i in range(0, len(indices), batch_size * 100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))

    pooled_indices = [x[0] for x in pooled_indices]

    # yield indices for current batch
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]

bucket_dataloader = DataLoader(train_list, batch_sampler=batch_sampler(),

I’m wondering how I should adapt it to a distributed setting.


What you can do is create a class that inherits from DistributedSampler and modify it so that it behaves the way you need.

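from typing import Iterator, List, Optional

from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler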

class MyDistributedBatchSampler(DistributedSampler):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # super().__init__ is already bound, so self must not be passed again
        super().__init__(dataset, num_replicas, rank, shuffle,
                         seed, drop_last)
        # Initialize your stuff

    def __iter__(self) -> Iterator[List[int]]:
        # Here you need to look at how DistributedSampler implements __iter__
        # Then you can modify it to serve your purposes
        raise NotImplementedError  # placeholder so the class body is valid Python

    # If len stays the same you can leave it out, else you can also modify it
    #def __len__(self):
        #return self.num_samples
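
To make this more concrete, here is a minimal sketch of how the pooling idea from the question could be combined with per-rank sharding. It is only one possible approach, not the way DistributedSampler itself works: it does not subclass DistributedSampler (so it runs without torch) and only mirrors its num_replicas / rank / set_epoch conventions. The class name, the lengths argument (precomputed sequence lengths) and pool_factor are made up for illustration. It relies on batch_sampler accepting any iterable that yields lists of indices, and it assumes every rank uses the same seed and epoch so that all ranks build the same list of batches.

```python
import random
from typing import Iterator, List, Sequence


class DistributedBucketBatchSampler:
    """Hypothetical batch sampler: length-bucketed batches, sharded by rank."""

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 num_replicas: int, rank: int,
                 pool_factor: int = 100, seed: int = 0) -> None:
        if not 0 <= rank < num_replicas:
            raise ValueError("rank must be in [0, num_replicas)")
        self.lengths = lengths
        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.pool_factor = pool_factor
        self.seed = seed
        self.epoch = 0

    def set_epoch(self, epoch: int) -> None:
        # Call once per epoch on every rank to get a new, shared shuffle.
        self.epoch = epoch

    def __iter__(self) -> Iterator[List[int]]:
        # Same seed + epoch on every rank -> identical shuffle everywhere.
        rng = random.Random(self.seed + self.epoch)
        indices = list(range(len(self.lengths)))
        rng.shuffle(indices)

        pool_size = self.batch_size * self.pool_factor
        batches: List[List[int]] = []
        for start in range(0, len(indices), pool_size):
            # Sort each random pool by length, then cut it into batches.
            pool = sorted(indices[start:start + pool_size],
                          key=self.lengths.__getitem__)
            batches.extend(pool[i:i + self.batch_size]
                           for i in range(0, len(pool), self.batch_size))

        # Drop trailing batches so every rank gets the same number of batches.
        usable = len(batches) - len(batches) % self.num_replicas
        return iter(batches[self.rank:usable:self.num_replicas])

    def __len__(self) -> int:
        pool_size = self.batch_size * self.pool_factor
        full_pools, rest = divmod(len(self.lengths), pool_size)
        total = full_pools * self.pool_factor + -(-rest // self.batch_size)
        return total // self.num_replicas
```

You would then pass an instance as DataLoader(train_list, batch_sampler=sampler, ...), take rank and num_replicas from torch.distributed.get_rank() and torch.distributed.get_world_size(), and call sampler.set_epoch(epoch) at the start of each epoch. Note that this sketch drops up to num_replicas - 1 batches per epoch to keep the ranks in step.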

Hi Matias,

Thanks for the help.
I just found out that I shouldn’t implement my own Sampler, since I’m dealing with iterable-style datasets, as suggested in the doc here.
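
For anyone landing here with the same setup: with an iterable-style dataset, the sharding and the length bucketing have to happen inside the stream itself rather than in a sampler. Below is a rough sketch of that idea using plain generators. The helper names shard and bucketed_batches are hypothetical, not part of PyTorch. The assumption is that they would be called from the __iter__ of an IterableDataset, with rank and world_size taken from torch.distributed, and worker_id and num_workers taken from torch.utils.data.get_worker_info().

```python
import itertools
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shard(stream: Iterable[T], rank: int, world_size: int,
          worker_id: int = 0, num_workers: int = 1) -> Iterator[T]:
    """Keep every k-th item so each (rank, worker) pair sees a disjoint slice."""
    shard_id = rank * num_workers + worker_id
    num_shards = world_size * num_workers
    return itertools.islice(stream, shard_id, None, num_shards)


def bucketed_batches(stream: Iterable[T], batch_size: int,
                     length_fn: Callable[[T], int] = len,
                     pool_factor: int = 100) -> Iterator[List[T]]:
    """Buffer a pool of items, sort it by length, and emit it as batches."""
    it = iter(stream)
    pool_size = batch_size * pool_factor
    while True:
        pool = list(itertools.islice(it, pool_size))
        if not pool:
            return
        pool.sort(key=length_fn)
        for start in range(0, len(pool), batch_size):
            yield pool[start:start + batch_size]
```

Since the dataset would already yield whole batches in this design, the DataLoader would be created with batch_size=None so that it does not batch again. One caveat with this sketch: ranks can end up with different numbers of batches when the stream does not divide evenly, and that needs to be handled separately in distributed training.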
