I’ve taken the liberty of making some slight modifications, since it wasn’t running out of the box; see the full code below. The changes in a nutshell:

- Since it’s a sampler, I renamed it to `BucketBatchSampler` :). Previously, `BucketDataset` was a crude mix of `Dataset` and `Sampler`. Your approach is much cleaner.
- Since it’s now only a sampler, there’s no need to keep `self.inputs` and `self.targets` any longer. This saves a lot of memory.
- In your code, `batch_count()` and `__len__()` were no longer working, since there’s no `self.batch_list` and `self.lengths` anymore. I’ve updated this.
- I’ve added a “proper” `BucketDataset` class that implements `torch.utils.data.Dataset`.
The usage now looks like this:
```python
X = <data as np.array>
bucket_batch_sampler = BucketBatchSampler(X, BATCH_SIZE)  # <-- does not store X
bucket_dataset = BucketDataset(X, None)
# batch_sampler is mutually exclusive with batch_size, shuffle, sampler,
# and drop_last, so all of those are left at their defaults
dataloader = DataLoader(bucket_dataset, batch_sampler=bucket_batch_sampler, num_workers=8)
```
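Since each emitted batch indexes only equal-length sequences, the default collate function can stack every batch into one tensor without padding. A minimal sketch of consuming the loader (illustrative only, assuming `X` holds 2-D arrays of shape `(seq_len, n_features)`):

```python
for batch in dataloader:
    # seq_len differs between batches but is constant within a batch, so
    # default_collate stacks without padding; the last batch of a bucket
    # may hold fewer than BATCH_SIZE sequences
    print(batch.shape)  # (<= BATCH_SIZE, seq_len, n_features)
```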
The full code:
```python
from torch.utils.data import Sampler, Dataset
from collections import OrderedDict
from random import shuffle


class BucketDataset(Dataset):

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        if self.targets is None:
            return self.inputs[index]
        else:
            return self.inputs[index], self.targets[index]


class BucketBatchSampler(Sampler):

    # want inputs to be an array
    def __init__(self, inputs, batch_size):
        self.batch_size = batch_size
        ind_n_len = []
        for i, p in enumerate(inputs):
            ind_n_len.append((i, p.shape[0]))
        self.ind_n_len = ind_n_len
        self.batch_list = self._generate_batch_map()
        self.num_batches = len(self.batch_list)

    def _generate_batch_map(self):
        # shuffle all of the indices first so they are put into buckets differently
        shuffle(self.ind_n_len)
        # Organize lengths, e.g., batch_map[10] = [30, 124, 203, ...] <= indices of sequences of length 10
        batch_map = OrderedDict()
        for idx, length in self.ind_n_len:
            if length not in batch_map:
                batch_map[length] = [idx]
            else:
                batch_map[length].append(idx)
        # Use batch_map to split indices into batches of equal size
        # e.g., for batch_size=3, batch_list = [[23,45,47], [49,50,62], [63,65,66], ...]
        batch_list = []
        for length, indices in batch_map.items():
            for group in [indices[i:(i + self.batch_size)] for i in range(0, len(indices), self.batch_size)]:
                batch_list.append(group)
        return batch_list
    def batch_count(self):
        return self.num_batches

    def __len__(self):
        # a batch sampler's __len__ must report the number of batches it
        # yields, not the number of samples; otherwise len(dataloader) is wrong
        return self.num_batches
    def __iter__(self):
        self.batch_list = self._generate_batch_map()
        # shuffle all the batches so they aren't ordered by bucket size
        shuffle(self.batch_list)
        for i in self.batch_list:
            yield i
```
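To sanity-check the whole pipeline, here is a small hypothetical smoke test (toy data, not part of the original snippet). It builds 100 random-length sequences and verifies that the loader reports the right number of batches and that every batch stacks cleanly, which only succeeds if all sequences in a batch share one length:

```python
import numpy as np
from torch.utils.data import DataLoader

# 100 toy sequences with random lengths between 5 and 15, 8 features each
X = [np.random.rand(np.random.randint(5, 16), 8).astype(np.float32)
     for _ in range(100)]

sampler = BucketBatchSampler(X, batch_size=4)
loader = DataLoader(BucketDataset(X, None), batch_sampler=sampler, num_workers=0)

assert len(loader) == sampler.batch_count()
for batch in loader:
    assert batch.dim() == 3  # stacking succeeded, so lengths were uniform
```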