I have wrapped @vdw’s bucket-by-length approach (which removes any need for padding!) in a BatchSampler object and added shuffling of both the data and the buckets to improve convergence during training. The perk of it being a BatchSampler is that you can pass it to a DataLoader and parallelize feeding data to the GPU.
The BatchSampler:
(Here I am using an autoencoder, so I don't pass targets to the constructor or yield them.)
from collections import OrderedDict
from random import shuffle

import numpy as np
from torch.utils.data import DataLoader, Sampler
class BucketDataset(Sampler):
    # `inputs` should be an array of N sequences (lengths may vary)
    def __init__(self, inputs, batch_size):
        self.inputs = inputs
        self.batch_size = batch_size
        # record (index, sequence length) for every input
        ind_n_len = []
        for i, p in enumerate(inputs):
            ind_n_len.append((i, p.shape[0]))
        self.ind_n_len = ind_n_len
    def _generate_batch_map(self):
        # shuffle all of the indices first so they are put into buckets differently
        shuffle(self.ind_n_len)
        batch_map = OrderedDict()
        # Organize indices by length, e.g., batch_map[10] = [30, 124, 203, ...]
        # <= indices of sequences of length 10
        for idx, length in self.ind_n_len:
            if length not in batch_map:
                batch_map[length] = [idx]
            else:
                batch_map[length].append(idx)
        # Use batch_map to split indices into batches of equal size,
        # e.g., for batch_size=3: batch_list = [[23,45,47], [49,50,62], ...]
        batch_list = []
        for length, indices in batch_map.items():
            for group in [indices[i:(i + self.batch_size)]
                          for i in range(0, len(indices), self.batch_size)]:
                batch_list.append(group)
        return batch_list
    def batch_count(self):
        return len(self._generate_batch_map())

    def __len__(self):
        # number of batches the sampler yields per epoch
        return self.batch_count()
    def __iter__(self):
        batch_list = self._generate_batch_map()
        shuffle(batch_list)  # shuffle the batches so they aren't ordered by bucket size
        for batch in batch_list:
            yield batch
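As a quick sanity check, here is a minimal sketch (the random toy data below is my own assumption, not from the original post) confirming that every batch the sampler yields indexes sequences of a single length:

import numpy as np

# hypothetical toy data: 100 one-dimensional sequences of varying length
data = np.array([np.random.rand(np.random.randint(5, 21)) for _ in range(100)],
                dtype=object)

sampler = BucketDataset(data, batch_size=4)
for batch in sampler:
    lengths = {data[idx].shape[0] for idx in batch}
    assert len(lengths) == 1  # all sequences in this batch share one length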
Calling the BatchSampler and DataLoader:
sampler = BucketDataset(<your data in an np.array>, BATCH_SIZE)
dataloader = DataLoader(<your data as a Dataset object>,
                        batch_sampler=sampler,
                        num_workers=8)
Note that batch_size, shuffle, sampler, and drop_last are mutually exclusive with batch_sampler, so they have to be left at their defaults.
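Since every batch contains sequences of one length, the DataLoader's default collate function can stack them into a single tensor with no padding. For completeness, a minimal sketch of a Dataset wrapper (the SequenceDataset class is my own assumption; it reuses the toy data array from the sanity check above):

import torch
from torch.utils.data import DataLoader, Dataset

class SequenceDataset(Dataset):
    # wraps an object array of variable-length sequences
    def __init__(self, inputs):
        self.inputs = inputs

    def __getitem__(self, idx):
        return torch.as_tensor(self.inputs[idx], dtype=torch.float32)

    def __len__(self):
        return len(self.inputs)

sampler = BucketDataset(data, batch_size=4)
dataloader = DataLoader(SequenceDataset(data), batch_sampler=sampler)

for batch in dataloader:
    print(batch.shape)  # (batch_size, seq_len): no padding needed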