I’ve solved the problem by combining `DistributedSampler` with the `batch_sampler` function from the migration tutorial.

Note that this implementation is for a seq2seq problem, so each dataset item is a tuple with 4 values: source sequence, source length, target sequence, and target length. The dataset format only matters in the first few lines; the rest of the logic is generic.

I’ve abstracted the length sorting into a BatchSamplerSimilarLength Sampler class which can be used without any DDP.

```
import random

import torch
from torch.utils.data import Sampler


class BatchSamplerSimilarLength(Sampler):
    def __init__(self, dataset, batch_size, indices=None, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        # get the indices and source lengths
        self.indices = [(i, src_len) for i, (src, src_len, trg, trg_len)
                        in enumerate(dataset)]
        # if indices are passed, then use only the ones passed (for ddp)
        if indices is not None:
            self.indices = torch.tensor(self.indices)[indices].tolist()

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.indices)

        pooled_indices = []
        # create pools of indices with similar lengths
        for i in range(0, len(self.indices), self.batch_size * 100):
            pooled_indices.extend(sorted(self.indices[i:i + self.batch_size * 100],
                                         key=lambda x: x[1]))
        self.pooled_indices = [x[0] for x in pooled_indices]

        # yield indices for the current batch
        batches = [self.pooled_indices[i:i + self.batch_size]
                   for i in range(0, len(self.pooled_indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)
        for batch in batches:
            yield batch

    def __len__(self):
        # use self.indices here: self.pooled_indices only exists
        # after __iter__ has been called at least once
        return len(self.indices) // self.batch_size
```
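For reference, here is a minimal sketch of using the sampler on its own, without DDP. The toy dataset below is hypothetical (random-length tensors in the 4-tuple format described above), and the class definition is repeated so the snippet runs standalone:

```python
import random

import torch
from torch.utils.data import DataLoader, Sampler


# definition from above, repeated so this snippet is self-contained
class BatchSamplerSimilarLength(Sampler):
    def __init__(self, dataset, batch_size, indices=None, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = [(i, src_len) for i, (src, src_len, trg, trg_len)
                        in enumerate(dataset)]
        if indices is not None:
            self.indices = torch.tensor(self.indices)[indices].tolist()

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.indices)
        pooled_indices = []
        for i in range(0, len(self.indices), self.batch_size * 100):
            pooled_indices.extend(sorted(self.indices[i:i + self.batch_size * 100],
                                         key=lambda x: x[1]))
        self.pooled_indices = [x[0] for x in pooled_indices]
        batches = [self.pooled_indices[i:i + self.batch_size]
                   for i in range(0, len(self.pooled_indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)
        for batch in batches:
            yield batch

    def __len__(self):
        return len(self.indices) // self.batch_size


# hypothetical toy dataset: (src, src_len, trg, trg_len) tuples of random length
data = [(torch.randint(0, 100, (n,)), n,
         torch.randint(0, 100, (n + 1,)), n + 1)
        for n in [random.randint(5, 50) for _ in range(64)]]

loader = DataLoader(data,
                    batch_sampler=BatchSamplerSimilarLength(data, batch_size=8),
                    collate_fn=list)  # keep variable-length tuples as-is

for batch in loader:
    src_lens = [src_len for _, src_len, _, _ in batch]
    # within each batch, source lengths come out sorted (similar lengths)
    assert src_lens == sorted(src_lens)
```

With a real collate function you would pad each batch to its local maximum length, which is where the savings from length-similar batching come from.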

Then I’ve created a child of `DistributedSampler` that overrides the `__iter__` method: it passes the per-rank indices computed by `DistributedSampler` to the `BatchSamplerSimilarLength`, so the batch sampler only sees the indices assigned to the current rank and can then proceed as usual (batching by length).

```
from typing import Optional

from torch.utils.data import Dataset, DistributedSampler


class DistributedBatchSamplerSimilarLength(DistributedSampler):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False,
                 batch_size: int = 10) -> None:
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed, drop_last=drop_last)
        self.batch_size = batch_size

    def __iter__(self):
        # indices assigned to the current rank by DistributedSampler
        indices = list(super().__iter__())
        batch_sampler = BatchSamplerSimilarLength(self.dataset,
                                                  batch_size=self.batch_size,
                                                  indices=indices)
        return iter(batch_sampler)

    def __len__(self) -> int:
        return self.num_samples // self.batch_size
```

The `DistributedBatchSamplerSimilarLength` has to be passed as the `batch_sampler` argument to the `DataLoader`.
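A sketch of the wiring, with both class definitions repeated so it runs standalone. `num_replicas` and `rank` are passed explicitly here so the example works without an initialized process group; under real DDP they default to the process group's world size and rank, and the toy dataset is the same hypothetical 4-tuple format as above:

```python
import random
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, DistributedSampler, Sampler


# definitions from above, repeated so this snippet is self-contained
class BatchSamplerSimilarLength(Sampler):
    def __init__(self, dataset, batch_size, indices=None, shuffle=True):
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = [(i, src_len) for i, (src, src_len, trg, trg_len)
                        in enumerate(dataset)]
        if indices is not None:
            self.indices = torch.tensor(self.indices)[indices].tolist()

    def __iter__(self):
        if self.shuffle:
            random.shuffle(self.indices)
        pooled_indices = []
        for i in range(0, len(self.indices), self.batch_size * 100):
            pooled_indices.extend(sorted(self.indices[i:i + self.batch_size * 100],
                                         key=lambda x: x[1]))
        self.pooled_indices = [x[0] for x in pooled_indices]
        batches = [self.pooled_indices[i:i + self.batch_size]
                   for i in range(0, len(self.pooled_indices), self.batch_size)]
        if self.shuffle:
            random.shuffle(batches)
        for batch in batches:
            yield batch

    def __len__(self):
        return len(self.indices) // self.batch_size


class DistributedBatchSamplerSimilarLength(DistributedSampler):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False,
                 batch_size: int = 10) -> None:
        super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed, drop_last=drop_last)
        self.batch_size = batch_size

    def __iter__(self):
        indices = list(super().__iter__())
        return iter(BatchSamplerSimilarLength(self.dataset,
                                              batch_size=self.batch_size,
                                              indices=indices))

    def __len__(self) -> int:
        return self.num_samples // self.batch_size


# hypothetical toy dataset: (src, src_len, trg, trg_len) tuples
data = [(torch.randint(0, 100, (n,)), n,
         torch.randint(0, 100, (n + 1,)), n + 1)
        for n in [random.randint(5, 50) for _ in range(64)]]

# explicit num_replicas/rank: no process group needed for this demo
sampler = DistributedBatchSamplerSimilarLength(data, num_replicas=2, rank=0,
                                               batch_size=8)
loader = DataLoader(data, batch_sampler=sampler, collate_fn=list)

for epoch in range(2):
    sampler.set_epoch(epoch)  # reshuffle across epochs, as with DistributedSampler
    for batch in loader:
        pass  # this rank sees only its shard, batched by similar length
```

As with a plain `DistributedSampler`, calling `set_epoch` at the start of each epoch is what makes the shuffling differ across epochs while staying consistent across ranks.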