How to shuffle an iterable dataset

Hi,

I am using the IterableDataset class to avoid loading the whole dataset into memory. However, I cannot shuffle the data in that case. Is there any trick for shuffling an iterable dataset?

Thanks in advance!

1 Like

No.
The typical approach is to use a conventional (map-style) dataset that loads the data lazily in its __getitem__ method.
Look at TorchVision’s DatasetFolder (used via ImageFolder for ImageNet) for inspiration.

The deeper reason is that shuffling requires, by definition, “random access” of some sort. That’s what classic datasets are for.
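For example, a minimal map-style sketch along those lines (LazyFileDataset, file_paths and load_fn are placeholder names, not anything from TorchVision): only the paths are kept in memory, each sample is loaded on demand, and the index-based access is what lets the DataLoader shuffle.

import torch
from torch.utils.data import Dataset, DataLoader

class LazyFileDataset(Dataset):
  def __init__(self, file_paths, load_fn):
    # Only the list of paths lives in memory; the samples stay on disk.
    self.file_paths = file_paths
    self.load_fn = load_fn  # e.g. reads one file and returns a tensor

  def __len__(self):
    return len(self.file_paths)

  def __getitem__(self, idx):
    # Random access by index: the DataLoader can request any idx in any order,
    # which is what makes shuffle=True possible.
    return self.load_fn(self.file_paths[idx])

# loader = DataLoader(LazyFileDataset(paths, load_fn), batch_size=32, shuffle=True)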

Best regards

Thomas

1 Like

Thanks!
But in that case, do you think there are neural networks that can be trained on a non-shuffled dataset, so that the IterableDataset class could be applied? AFAIK, any neural network trained with mini-batches requires the dataset to be shuffled.

Well, it depends a bit. There are cases, such as language models, where the text is decidedly not shuffled. Feeding a dataset sorted by category into a classification network is probably not a good idea, but a completely random order is likely not always necessary.
That said, I’d probably use a classic dataset unless you know you cannot use it (i.e. take the common route if you can).

Best regards

Thomas

1 Like

Also, there are cases where your data is generated or sampled on the fly (from a distribution, from a simulation), where iterable datasets are more suitable.
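For instance, a minimal sketch of such an on-the-fly source (GaussianStream is a made-up name, not an existing class): every sample is freshly drawn, so there is nothing stored that would need shuffling.

import torch

class GaussianStream(torch.utils.data.IterableDataset):
  def __init__(self, num_samples):
    super().__init__()
    self.num_samples = num_samples

  def __iter__(self):
    # Samples are generated on the fly instead of being read from storage.
    for _ in range(self.num_samples):
      x = torch.randn(10)           # features drawn from N(0, 1)
      y = (x.sum() > 0).long()      # synthetic label
      yield x, y

# loader = torch.utils.data.DataLoader(GaussianStream(100_000), batch_size=32)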

1 Like

I think the standard approach to shuffling an iterable dataset is to introduce a shuffle buffer into your pipeline: keep a fixed-size buffer of items and, for every new item read from the underlying iterator, yield a randomly chosen buffered item and put the new item in its place. Here’s the class I use to shuffle an iterable dataset:

import random
import torch

class ShuffleDataset(torch.utils.data.IterableDataset):
  def __init__(self, dataset, buffer_size):
    super().__init__()
    self.dataset = dataset
    self.buffer_size = buffer_size

  def __iter__(self):
    shufbuf = []
    dataset_iter = iter(self.dataset)
    # Fill the shuffle buffer first; the dataset may hold fewer items than buffer_size.
    try:
      for _ in range(self.buffer_size):
        shufbuf.append(next(dataset_iter))
    except StopIteration:
      pass

    # For every new item, yield a random element from the buffer
    # and put the new item in its place.
    for item in dataset_iter:
      evict_idx = random.randint(0, len(shufbuf) - 1)
      yield shufbuf[evict_idx]
      shufbuf[evict_idx] = item

    # Drain the remaining buffer in random order.
    random.shuffle(shufbuf)
    while shufbuf:
      yield shufbuf.pop()

You’d wrap your existing iterable dataset like this:

dataset = MyIterableDataset()
dataset = ShuffleDataset(dataset, 1024)  # shuffle buffer size depends on your application
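When passing the wrapped dataset to a DataLoader, leave shuffle at its default of False: the buffer already does the shuffling, and setting shuffle=True on an iterable dataset raises an error.

loader = torch.utils.data.DataLoader(dataset, batch_size=32)  # no shuffle=True here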
10 Likes

The algorithm described above is now implemented natively in PyTorch as BufferedShuffleDataset, documented here.
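Assuming the PyTorch 1.8-era API, usage would be roughly analogous to the wrapper above (MyIterableDataset is the same placeholder as before):

from torch.utils.data import BufferedShuffleDataset

dataset = MyIterableDataset()
dataset = BufferedShuffleDataset(dataset, buffer_size=1024)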

3 Likes

Is BufferedShuffleDataset gone? Or maybe it was replaced by another class? I can’t find it in the latest docs.

1 Like

Yeah, I’m not sure where it went. I don’t see any issues on GitHub about removing it. It shows up in the 1.8.0 docs, but it’s gone in the current version.

It seems the removal wasn’t documented and no deprecation warning was used.
I could find this and this issue, which seem to point towards

torch.utils.data.datapipes.iter.combinatorics.ShuffleIterDataPipe

as its replacement.

4 Likes