I’m working on sequence data and would like to group sequences of similar lengths into batches. Here is an example implementation (source)
"""
To group the texts with similar length together, like introduced in the legacy BucketIterator class,
first of all, we randomly create multiple "pools", and each of them has a size of batch_size * 100.
Then, we sort the samples within the individual pool by length.
This idea can be implemented succinctly through the batch_sampler argument of the PyTorch DataLoader.
batch_sampler accepts 'Sampler' or Iterable object that yields indices of next batch.
In the code below, we implemented a generator that yields batch of indices for which the corresponding batch of data is of similar length.
"""
import random

from torch.utils.data import DataLoader
from torchtext.datasets import IMDB

# tokenizer and collate_batch are defined earlier in the tutorial this snippet comes from.
train_iter = IMDB(split='train')
train_list = list(train_iter)
batch_size = 8

def batch_sampler():
    # pair each sample index with its token length; each sample is a (label, text) tuple
    indices = [(i, len(tokenizer(s[1]))) for i, s in enumerate(train_list)]
    random.shuffle(indices)

    pooled_indices = []
    # create pools of indices and sort each pool by length
    for i in range(0, len(indices), batch_size * 100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
    pooled_indices = [x[0] for x in pooled_indices]

    # yield indices for the current batch
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]

bucket_dataloader = DataLoader(train_list, batch_sampler=batch_sampler(),
                               collate_fn=collate_batch)
I’m wondering what I should do to adapt it to a distributed setup.
What you can do is create a class that inherits from DistributedSampler and modify it so that it behaves the way you need.
from typing import Iterator, List, Optional

from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler

class MyDistributedBatchSampler(DistributedSampler):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        super().__init__(dataset, num_replicas, rank, shuffle,
                         seed, drop_last)
        # Initialize your stuff here

    def __iter__(self) -> Iterator[List[int]]:
        # Here you need to look at how DistributedSampler implements __iter__
        # https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler
        # Then you can modify it to serve your purposes
        ...

    # If __len__ stays the same you can leave it out, else modify it as well
    # def __len__(self):
    #     return self.num_samples
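To make this concrete, below is a minimal sketch of how the two pieces could be combined: the rank-based sharding that DistributedSampler.__iter__ performs, followed by the bucketing logic from your batch_sampler applied only to this rank's shard. The class name DistributedBucketBatchSampler, the pool_factor argument, and the idea of passing a precomputed list of token lengths instead of the dataset are my own choices (and the padding step is simplified compared with the real DistributedSampler), so treat it as a starting point rather than a drop-in implementation.

import math
import random
from typing import Iterator, List, Optional

import torch
from torch.utils.data.distributed import DistributedSampler

class DistributedBucketBatchSampler(DistributedSampler):
    """Shards the data across ranks, then buckets each rank's shard by length
    and yields batches of indices whose samples have similar lengths."""

    def __init__(self, lengths: List[int], batch_size: int, pool_factor: int = 100,
                 num_replicas: Optional[int] = None, rank: Optional[int] = None,
                 shuffle: bool = True, seed: int = 0, drop_last: bool = False) -> None:
        # DistributedSampler only calls len() on its dataset argument,
        # so a precomputed list of sequence lengths works as a stand-in.
        super().__init__(lengths, num_replicas=num_replicas, rank=rank,
                         shuffle=shuffle, seed=seed, drop_last=drop_last)
        self.lengths = lengths
        self.batch_size = batch_size
        self.pool_size = batch_size * pool_factor

    def __iter__(self) -> Iterator[List[int]]:
        # Part 1: rank-based sharding, following DistributedSampler.__iter__.
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.lengths), generator=g).tolist()
        else:
            indices = list(range(len(self.lengths)))
        if not self.drop_last:
            # pad (simplified) so every rank gets the same number of samples
            indices += indices[:self.total_size - len(indices)]
        else:
            indices = indices[:self.total_size]
        indices = indices[self.rank:self.total_size:self.num_replicas]

        # Part 2: the bucketing from the single-GPU batch_sampler, applied to this shard.
        batches = []
        for i in range(0, len(indices), self.pool_size):
            pool = sorted(indices[i:i + self.pool_size], key=lambda j: self.lengths[j])
            for j in range(0, len(pool), self.batch_size):
                batches.append(pool[j:j + self.batch_size])
        # shuffle the batch order so a rank does not always see the shortest batches first
        random.Random(self.seed + self.epoch).shuffle(batches)
        return iter(batches)

    def __len__(self) -> int:
        # number of batches this rank yields, not the number of samples
        return math.ceil(self.num_samples / self.batch_size)

On each rank the usage would then look roughly like this (tokenizer and collate_batch as in your snippet); remember to call set_epoch so the shuffling changes between epochs:

lengths = [len(tokenizer(text)) for label, text in train_list]
sampler = DistributedBucketBatchSampler(lengths, batch_size=8)
bucket_dataloader = DataLoader(train_list, batch_sampler=sampler, collate_fn=collate_batch)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)
    for batch in bucket_dataloader:
        ...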
Just adding on top of that: I recently realized that we have to derive from DistributedSampler here, rather than from a normal sampler, or data sharding will happen twice in multi-GPU training.
The sampler creates indices based on its rank and thus loads only the corresponding samples. Let me know if this clarifies it, or if a link to the source might be helpful (I’m on my phone now, otherwise I would have posted it directly).
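For reference, the rank-dependent part of DistributedSampler.__iter__ essentially boils down to a strided slice over the (optionally shuffled and padded) index list; this is paraphrased from the linked source, so check the current version for the exact details:

# each rank keeps every num_replicas-th index starting at its own rank,
# so the ranks end up with disjoint subsets of the data
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples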