How to shuffle an iterable dataset

I think the standard approach to shuffling an iterable dataset is to introduce a shuffle buffer into your pipeline: keep a fixed-size buffer of items, and for each new item yield a random element from the buffer and put the new item in its place. Here’s the class I use:

import random

import torch


class ShuffleDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, buffer_size):
        super().__init__()
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self):
        dataset_iter = iter(self.dataset)

        # Fill the buffer; if the dataset has fewer items than buffer_size,
        # just work with whatever we managed to collect.
        shufbuf = []
        try:
            for _ in range(self.buffer_size):
                shufbuf.append(next(dataset_iter))
        except StopIteration:
            pass

        # For every remaining item: yield a random element from the buffer,
        # then store the new item in the slot that was just freed.
        while True:
            try:
                item = next(dataset_iter)
            except StopIteration:
                break
            evict_idx = random.randint(0, len(shufbuf) - 1)
            yield shufbuf[evict_idx]
            shufbuf[evict_idx] = item

        # Drain whatever is left in the buffer, in random order.
        random.shuffle(shufbuf)
        while shufbuf:
            yield shufbuf.pop()

You’d wrap your existing iterable dataset like this:

dataset = MyIterableDataset()
dataset = ShuffleDataset(dataset, 1024)  # shuffle buffer size depends on your application
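
For reference, here’s a minimal end-to-end sketch. The MyIterableDataset below is just a hypothetical stand-in that yields integers, and the batch size is arbitrary; with the default single-process loading, the first batch comes out as a random draw from roughly the first buffer_size items:

import torch
from torch.utils.data import DataLoader, IterableDataset


# Hypothetical stand-in for your own iterable dataset: yields the integers 0..9999.
class MyIterableDataset(IterableDataset):
    def __iter__(self):
        return iter(range(10_000))


dataset = ShuffleDataset(MyIterableDataset(), buffer_size=1024)
loader = DataLoader(dataset, batch_size=8)

# With a buffer of 1024, the first batch is drawn in random order from
# (roughly) the first 1024 items rather than being 0, 1, 2, ...
print(next(iter(loader)))

Keep in mind the buffer only gives approximate shuffling: an item can never be emitted more than about buffer_size positions earlier than where it originally appeared, so a larger buffer gives a more thorough shuffle at the cost of memory.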