Tensorflow-esque bucket by sequence length

I’ve taken the liberty of making some slight modifications, since it wasn’t running out of the box; see the full code below. The changes in a nutshell:

  • Since it’s a sampler, I renamed it to BucketBatchSampler :). Previously, BucketDataset was a crude mix of Dataset and Sampler. Your approach is much cleaner.
  • Since it’s now only a sampler, there’s no need to keep self.inputs and self.targets any longer. Saves a lot of memory.
  • In your code, batch_count() and __len__() were no longer working, since there’s no self.batch_list or self.lengths anymore. I’ve updated this.
  • I’ve added a “proper” BucketDataset class that implements torch.utils.data.Dataset.
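Because bucket membership depends only on the sequence lengths, the number of batches per epoch stays fixed even though the batch map is regenerated on every iteration: a bucket holding n sequences always yields ceil(n / batch_size) batches, the last one possibly smaller. A minimal sketch of that count, using a hypothetical helper `count_bucket_batches` that takes a plain list of lengths rather than the arrays themselves:

```python
import math
from collections import Counter


def count_bucket_batches(lengths, batch_size):
    """Batches per epoch under length bucketing (hypothetical helper).

    Each distinct length forms one bucket; a bucket of n samples is split
    into ceil(n / batch_size) batches.
    """
    bucket_sizes = Counter(lengths)  # length -> number of samples
    return sum(math.ceil(n / batch_size) for n in bucket_sizes.values())


# Buckets: length 10 (5 samples), 12 (2 samples), 7 (1 sample), batch_size=3
print(count_bucket_batches([10, 10, 10, 10, 10, 12, 12, 7], 3))  # → 4
```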

The usage now looks like this:

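from torch.utils.data import DataLoader
BATCH_SIZE = 32  # example value
X = ...  # <data as np.array>: one array per sample, X[i].shape[0] is its length
bucket_batch_sampler = BucketBatchSampler(X, BATCH_SIZE)  # <-- does not store X
bucket_dataset = BucketDataset(X, None)
# batch_sampler is mutually exclusive with batch_size, shuffle, sampler and drop_last
dataloader = DataLoader(bucket_dataset, batch_sampler=bucket_batch_sampler, num_workers=8)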
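With a batch_sampler, the DataLoader takes each list of indices the sampler yields, fetches dataset[i] for every index in it, and hands the resulting list to collate_fn. The torch-free sketch below mimics that loop for the single-process case only; `iterate_like_dataloader`, `stack_equal_length` and the toy data are hypothetical stand-ins, not the DataLoader's actual implementation. It also shows why bucketing matters here: the default collate stacks samples, which requires equal shapes within a batch.

```python
def iterate_like_dataloader(dataset, batch_sampler, collate_fn):
    """Single-process sketch: one list of indices per batch."""
    for batch_indices in batch_sampler:
        samples = [dataset[i] for i in batch_indices]
        yield collate_fn(samples)


def stack_equal_length(samples):
    # Bucketing guarantees equal lengths within a batch, so samples can be
    # "stacked" (here: copied into a list of rows) without any padding.
    assert len({len(s) for s in samples}) == 1, "mixed lengths in one batch"
    return [list(s) for s in samples]


# Toy data with sequence lengths 3, 2, 3, 2, 3
data = [[1, 2, 3], [4, 5], [6, 7, 8], [9, 10], [11, 12, 13]]
batches = [[0, 2], [1, 3], [4]]  # what a bucket batch sampler might yield
for batch in iterate_like_dataloader(data, batches, stack_equal_length):
    print(batch)
```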

The full code:

from torch.utils.data import Sampler, Dataset
from collections import OrderedDict
from random import shuffle


class BucketDataset(Dataset):

    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, index):
        if self.targets is None:
            return self.inputs[index]
        else:
            return self.inputs[index], self.targets[index]


class BucketBatchSampler(Sampler):
    # inputs: indexable collection of arrays (e.g. np.arrays); p.shape[0] is taken as the sequence length
    def __init__(self, inputs, batch_size):
        self.batch_size = batch_size
        ind_n_len = []
        for i, p in enumerate(inputs):
            ind_n_len.append((i, p.shape[0]))
        self.ind_n_len = ind_n_len
        self.batch_list = self._generate_batch_map()
        self.num_batches = len(self.batch_list)

    def _generate_batch_map(self):
        # shuffle all of the indices first so they are put into buckets differently
        shuffle(self.ind_n_len)
        # Organize lengths, e.g., batch_map[10] = [30, 124, 203, ...] <= indices of sequences of length 10
        batch_map = OrderedDict()
        for idx, length in self.ind_n_len:
            if length not in batch_map:
                batch_map[length] = [idx]
            else:
                batch_map[length].append(idx)
        # Use batch_map to split indices into batches of equal size
        # e.g., for batch_size=3, batch_list = [[23,45,47], [49,50,62], [63,65,66], ...]
        batch_list = []
        for indices in batch_map.values():
            for i in range(0, len(indices), self.batch_size):
                batch_list.append(indices[i:(i + self.batch_size)])
        return batch_list

    def batch_count(self):
        return self.num_batches

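    def __len__(self):
        # A batch sampler's length is its number of batches (DataLoader uses it for len(dataloader)).
        # It is stable across epochs because bucket sizes depend only on the lengths.
        return self.num_batches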

    def __iter__(self):
        self.batch_list = self._generate_batch_map()
        # shuffle all the batches so they aren't ordered by bucket size
        shuffle(self.batch_list)
        for i in self.batch_list:
            yield i
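To sanity-check the bucketing logic without torch, the sketch below reimplements what _generate_batch_map() plus the batch shuffle in __iter__ do, as a plain function with the hypothetical name `bucket_batches`. It assumes a plain list of lengths as input and uses a seeded random.Random instance instead of the module-level shuffle so runs are reproducible. It checks the two properties the sampler relies on: every batch holds sequences of a single length, and every index appears exactly once per epoch.

```python
import random
from collections import defaultdict


def bucket_batches(lengths, batch_size, seed=None):
    """Group indices by length, split each bucket into batches, shuffle batches."""
    rng = random.Random(seed)
    ind_n_len = list(enumerate(lengths))
    rng.shuffle(ind_n_len)  # vary which indices end up together each epoch
    buckets = defaultdict(list)  # length -> indices with that length
    for idx, length in ind_n_len:
        buckets[length].append(idx)
    batches = [
        indices[i:i + batch_size]
        for indices in buckets.values()
        for i in range(0, len(indices), batch_size)
    ]
    rng.shuffle(batches)  # don't yield batches ordered by bucket
    return batches


lengths = [5, 3, 5, 5, 3, 8, 5, 3]
batches = bucket_batches(lengths, batch_size=2, seed=0)
# Every batch is length-homogeneous ...
assert all(len({lengths[i] for i in b}) == 1 for b in batches)
# ... and each index appears exactly once.
assert sorted(i for b in batches for i in b) == list(range(len(lengths)))
print(len(batches))  # → 5
```

Using a dedicated Random instance keeps the sampler's shuffling independent of any other code that reseeds the global random module.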