Hi,

I’m working with sequence data and would like to group sequences of similar lengths into batches. Here is an example implementation (source):

```
"""
To group texts of similar length together, as in the legacy BucketIterator class,
we first randomly create multiple "pools", each of size batch_size * 100,
and then sort the samples within each individual pool by length.
This idea can be implemented succinctly through the batch_sampler argument of the PyTorch DataLoader.
batch_sampler accepts a Sampler or an Iterable that yields the indices of the next batch.
The generator below yields batches of indices whose corresponding samples have similar lengths.
"""
import random

from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import IMDB

tokenizer = get_tokenizer('basic_english')

train_iter = IMDB(split='train')
train_list = list(train_iter)
batch_size = 8  # a batch size of 8


def batch_sampler():
    # pair each sample index with the token length of its text
    indices = [(i, len(tokenizer(s[1]))) for i, s in enumerate(train_list)]
    random.shuffle(indices)

    # create pools of indices and sort each pool by length
    pooled_indices = []
    for i in range(0, len(indices), batch_size * 100):
        pooled_indices.extend(sorted(indices[i:i + batch_size * 100], key=lambda x: x[1]))
    pooled_indices = [x[0] for x in pooled_indices]

    # yield the indices for the current batch
    for i in range(0, len(pooled_indices), batch_size):
        yield pooled_indices[i:i + batch_size]


# collate_batch is defined earlier in the tutorial this snippet comes from
bucket_dataloader = DataLoader(train_list, batch_sampler=batch_sampler(),
                               collate_fn=collate_batch)
```
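
As a quick sanity check (this part is just my own verification, not from the tutorial), iterating over the raw index batches shows that each batch indeed contains samples of similar token lengths:

```
# sanity check: inspect the token lengths inside the first few index batches
for batch_idx, batch in enumerate(batch_sampler()):
    lengths = [len(tokenizer(train_list[i][1])) for i in batch]
    print(f"batch {batch_idx}: lengths = {lengths}")
    if batch_idx == 2:
        break
```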

I’m wondering what I should do to adapt this to a distributed setting, for example something along the lines of the sketch below.
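
For context, here is a minimal sketch of the direction I was considering, assuming each rank should receive a disjoint slice of the batches (similar in spirit to what DistributedSampler does for per-sample sampling). The names `distributed_batch_sampler`, `rank`, and `world_size` are my own; the rank and world size would come from an already-initialized torch.distributed process group:

```
import torch.distributed as dist

def distributed_batch_sampler():
    """Naive sketch: give each rank every world_size-th batch of similar-length indices.
    Assumes torch.distributed has already been initialized (e.g. via torchrun),
    and that random.shuffle inside batch_sampler() is seeded identically on every
    rank so that all ranks see the same sequence of pooled batches."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    for batch_idx, batch in enumerate(batch_sampler()):
        # keep only the batches assigned to this rank
        if batch_idx % world_size == rank:
            yield batch

bucket_dataloader = DataLoader(train_list,
                               batch_sampler=distributed_batch_sampler(),
                               collate_fn=collate_batch)
```

Is this the right approach? I’m also unsure how to handle the case where the number of batches is not divisible by world_size, or whether there is an established utility for this.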

Thanks