Example for torch.utils.data.IterableDataset

I have an iterable dataset and want to write a DataLoader for it. In the tutorial, I only found this example:

It is not clear to me how to extend it to a real dataset. Could you provide an example where, given an iterable dataset, you write a DataLoader for it? Thanks.


What is it that you want? Please provide a bit more information about your use case.

I have a tf.data.Dataset, and I'd like to write an iterable DataLoader for PyTorch on top of it. I'd like to make sure it works with multiple workers, as in the example at https://pytorch.org/docs/stable/data.html#multi-process-data-loading. I also need to use it with PyTorch Lightning, so it has to work with PyTorch Lightning as well.

This is the example from the PyTorch tutorial. In `__iter__` they write "iter(range(iter_start, iter_end))", but in a real scenario one passes an iterable to this class and wants to iterate over it. Could you tell me how to change this example for an iterable dataset, in a way that works well with PyTorch Lightning and in a distributed setting on TPU/GPU? Thanks.

>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
>>> # should give same set of data as range(3, 7), i.e., [3, 4, 5, 6].
>>> ds = MyIterableDataset(start=3, end=7)

>>> # Single-process loading
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
[tensor([3]), tensor([4]), tensor([5]), tensor([6])]

>>> # Multi-process loading with two worker processes
>>> # Worker 0 fetched [3, 4].  Worker 1 fetched [5, 6].
>>> print(list(torch.utils.data.DataLoader(ds, num_workers=2)))
[tensor([3]), tensor([5]), tensor([4]), tensor([6])]
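One way to adapt this example to an arbitrary iterable (a sketch, not the approach from the docs) is to replace the start/end arithmetic with a strided slice: worker k keeps items k, k + n, k + 2n, ... of the stream, where n is the number of workers. `ShardedIterable` and `shard_iterable` are hypothetical names; the only library pieces used are `itertools.islice`, `IterableDataset` and `get_worker_info`. The source is assumed to be re-iterable (any object whose `__iter__` starts a fresh pass, such as a list or a `range`), so that each worker and each epoch begins at the start of the data.

```python
import itertools

from torch.utils.data import IterableDataset, get_worker_info


def shard_iterable(iterable, worker_id, num_workers):
    # Keep elements worker_id, worker_id + num_workers, worker_id + 2 * num_workers, ...
    return itertools.islice(iterable, worker_id, None, num_workers)


class ShardedIterable(IterableDataset):
    def __init__(self, source):
        super().__init__()
        # Assumption: `source` is re-iterable; a one-shot generator would be
        # exhausted after the first pass.
        self.source = source

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is None:  # single-process loading: yield everything
            return iter(self.source)
        # In a worker process: keep only this worker's share of the stream.
        return shard_iterable(self.source, worker_info.id, worker_info.num_workers)
```

Usage would be something like `DataLoader(ShardedIterable(source), batch_size=32, num_workers=2)`. Every worker still reads the full stream and discards the items that belong to the other workers, so this is only cheap when reading is fast compared to processing; for expensive reads, splitting at the file level is usually the better choice.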

Maybe you could use yield?

import torch

class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, image_paths):
        super().__init__()
        # this depends on your dataset; suppose your dataset contains
        # images whose paths you store in this list
        self.dataset = image_paths
    def __iter__(self):
        for image_path in self.dataset:
            # read_single_example is your own loading function (not part of PyTorch):
            # it reads an individual sample and its label
            sample, label = read_single_example(image_path)
            yield sample, label

Hi Ivan
Thank you. My question is about multiple workers: in the example given in the PyTorch tutorial, each worker computes the start and end of its part of the data, so the data is split between the workers. Please have a look at the `__iter__` function, where iter_start and iter_end are computed for each worker; the worker then iterates over that range. The PyTorch example uses iter(range(iter_start, iter_end)) because its input is just a range of integers, and I am not sure how to do the same with an iterable dataset. Thank you.
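The linked docs also describe a second way to split the work: a `worker_init_fn` that runs once in each worker process and reconfigures that worker's copy of the dataset through `get_worker_info().dataset`. Applied to the path-based `yield` dataset above, a sketch could look like this. `PathDataset`, `keep_shard` and `load_fn` are hypothetical names, and `load_fn` stands in for the user-supplied `read_single_example` helper.

```python
from torch.utils.data import IterableDataset, get_worker_info


class PathDataset(IterableDataset):
    def __init__(self, image_paths, load_fn):
        super().__init__()
        self.image_paths = list(image_paths)
        # Hypothetical user-supplied function mapping a path to (sample, label).
        self.load_fn = load_fn

    def keep_shard(self, worker_id, num_workers):
        # Strided split: worker k keeps paths k, k + n, k + 2n, ...
        self.image_paths = self.image_paths[worker_id::num_workers]

    def __iter__(self):
        for path in self.image_paths:
            yield self.load_fn(path)


def worker_init_fn(worker_id):
    # Runs once inside each worker process; info.dataset is that worker's own
    # copy of the dataset, so shrinking its path list affects only this worker.
    info = get_worker_info()
    info.dataset.keep_shard(info.id, info.num_workers)
```

Usage would be `DataLoader(PathDataset(paths, read_single_example), batch_size=16, num_workers=4, worker_init_fn=worker_init_fn)`. With `num_workers=0` the init function is never called, so the full path list is used.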

Here is also the link to the tutorial, where they explain that each worker needs to access a different chunk of data: https://pytorch.org/docs/stable/data.html#multi-process-data-loading

Also, with PyTorch Lightning, does the user still need to handle worker_info? Thanks.
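As far as I know, neither the DataLoader nor Lightning splits an `IterableDataset` for you (Lightning only injects a `DistributedSampler` for map-style datasets; it is worth checking the docs for your version), so every worker in every process would otherwise see the full stream. Under that assumption, a sketch that numbers every (process, worker) pair and strides over the stream with that global index. The rank and world size are passed in explicitly, and `rank_and_world_size`, `global_shard` and `DistributedShardedIterable` are hypothetical names.

```python
import itertools

import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info


def rank_and_world_size():
    # Call this in the training process itself (not inside a DataLoader worker),
    # after torch.distributed has been initialized; falls back to one process.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1


def global_shard(rank, world_size, worker_id, num_workers):
    # Number every (process, worker) pair from 0 to world_size * num_workers - 1.
    return rank * num_workers + worker_id, world_size * num_workers


class DistributedShardedIterable(IterableDataset):
    def __init__(self, source, rank=0, world_size=1):
        super().__init__()
        self.source = source  # assumed to be re-iterable
        # Stored on the object so every worker's copy of the dataset sees them.
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = get_worker_info()
        worker_id, num_workers = (0, 1) if info is None else (info.id, info.num_workers)
        shard_id, num_shards = global_shard(self.rank, self.world_size, worker_id, num_workers)
        # Keep items shard_id, shard_id + num_shards, shard_id + 2 * num_shards, ...
        return itertools.islice(self.source, shard_id, None, num_shards)
```

One caveat: shards can differ in length, and in DDP ranks that end up with different numbers of batches may stall collective operations; trimming the stream so every shard has the same length is one possible workaround.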

Why don't you simply turn your TensorFlow dataset into a list (since it's an iterable, you should be able to do so in a one-liner) and then solve the problem from there? That is, simply do:

tf_lst = list(tf_dataset)

Now you have a list that you can wrap in a new PyTorch dataset and use as you wish!
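If the data does fit in memory, the list can go into a map-style `Dataset`. In that case no `worker_info` handling is needed, because the DataLoader's sampler hands out indices and distributes the resulting batches across workers itself. A minimal sketch, with `ListDataset` as a hypothetical name; the TF elements would first need converting to NumPy or Python values (for example `list(tf_dataset.as_numpy_iterator())`), since the default collate function does not handle TF tensors.

```python
from torch.utils.data import DataLoader, Dataset


class ListDataset(Dataset):
    # Map-style dataset: the DataLoader's sampler picks indices, and with
    # num_workers > 0 the DataLoader assigns whole batches to workers itself.
    def __init__(self, items):
        self.items = list(items)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return self.items[index]


# Example: shuffling and batching come from the DataLoader, not the dataset.
loader = DataLoader(ListDataset(range(8)), batch_size=4, shuffle=True)
```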

This is a large dataset; it is not possible to load it into memory, which is why it is an iterable one.

Can't you use the range() method of the dataset to get around this issue? Maybe give each worker a chunk of data with different steps?

How would you write it in the case of iterable datasets?

Your TensorFlow dataset has a range method where you can specify the start, stop and step.
Your PyTorch dataset requires you to give a chunk of data to each worker, and you can identify each worker using get_worker_info(), if I'm not mistaken!

Consider the following example:

import tensorflow as tf 
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
dataset = dataset.map(lambda x: x*2)
worker1_data = dataset.range(0, 3, 1)
worker2_data = dataset.range(4, 8, 1)

Depending on the worker, you can use one of them.
You may also have a look at the tf.data.Dataset.shard() method, which was created for this very reason.

dataset.range does not work on an existing tf.data.Dataset: it is a static method, so it replaces the dataset with range values.
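Because `range` builds a new dataset of integers, a contiguous chunk of an existing dataset can instead be taken with `skip` and `take`, which both operate on the dataset they are called on (`shard(num_workers, worker_id)` is the strided alternative). A sketch reusing the toy dataset from the example above, assuming TensorFlow 2.x in eager mode; `contiguous_chunk` is a hypothetical helper, and `dataset_size` is passed in explicitly because the size of a streamed dataset is often unknown.

```python
import tensorflow as tf


def contiguous_chunk(dataset, worker_id, num_workers, dataset_size):
    # Same ceil(N / n) split as the PyTorch docs example.
    per_worker = -(-dataset_size // num_workers)
    # skip/take act on `dataset` itself, unlike the static Dataset.range.
    return dataset.skip(worker_id * per_worker).take(per_worker)


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
dataset = dataset.map(lambda x: x * 2)
worker1_data = contiguous_chunk(dataset, worker_id=0, num_workers=2, dataset_size=10)
worker2_data = contiguous_chunk(dataset, worker_id=1, num_workers=2, dataset_size=10)
```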

Could you provide a full example showing how to use iterable datasets with multiple workers? Thanks.

Here is a minimal example to show what I mean. Could you please help me complete it with multiple workers?

from itertools import cycle, islice

import tensorflow as tf
import torch
from torch.utils.data import DataLoader


def get_dummy_dataset():
    inputs = ["input 1",
              "input 2",
              "input 3",
              "input 4"]
    target = ["target 1",
              "target 2",
              "target 3",
              "target 4"]
    features = {"inputs": inputs, "targets": target}

    def my_fn(features):
        ret = {}
        for k, v in features.items():
            ret[f'{k}_plaintext'] = v
        return ret

    dataset = tf.data.Dataset.from_tensor_slices(features)
    dataset = dataset.map(my_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset


class WMTDataset(torch.utils.data.IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        dataset = get_dummy_dataset()
        self.dataset_size = 4
        self.batch_size = batch_size
        self.dataset = self.create_dataset(dataset)

    def __len__(self):
        return self.dataset_size

    def __iter__(self):
        return self.dataset

    def create_dataset(self, dataset):
        dataset = dataset.batch(self.batch_size, drop_remainder=False)
        return cycle(dataset)


iterable_dataset = WMTDataset(batch_size=2)
loader = DataLoader(iterable_dataset, batch_size=None)
for batch in islice(loader, 2):
    print("#########batch ", batch)

I'm not sure I understand what you want, but couldn't you generate a CSV file with the paths of all the samples in your dataset, then create a DataFrame and assign 1/n of the samples to each worker?
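A sketch of that idea, using the standard-library `csv` module in place of a pandas DataFrame and the same contiguous ceil(N / n) split as the docs example. The file is assumed to have a header row with a `path` column; `CsvPathDataset` and `worker_range` are hypothetical names, and the loop yields paths where real code would load the sample. Only the list of paths is held in memory, not the samples themselves.

```python
import csv
import math

from torch.utils.data import IterableDataset, get_worker_info


def worker_range(num_items, worker_id, num_workers):
    # Contiguous chunks of ceil(N / n) items, clamped to the end of the data.
    per_worker = math.ceil(num_items / num_workers)
    start = min(worker_id * per_worker, num_items)
    return start, min(start + per_worker, num_items)


class CsvPathDataset(IterableDataset):
    def __init__(self, csv_path, column="path"):
        super().__init__()
        # Assumption: the CSV has a header row containing `column`.
        with open(csv_path, newline="") as f:
            self.paths = [row[column] for row in csv.DictReader(f)]

    def __iter__(self):
        info = get_worker_info()
        worker_id, num_workers = (0, 1) if info is None else (info.id, info.num_workers)
        start, end = worker_range(len(self.paths), worker_id, num_workers)
        for path in self.paths[start:end]:
            yield path  # placeholder: load and return the sample stored at `path`
```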

Here is an example of using an iterable dataset with multiple workers:
https://medium.com/speechmatics/how-to-build-a-streaming-dataloader-with-pytorch-a66dd891d9dd
https://github.com/etienne87/pytorch-streamloader/blob/master/pytorch_iterable.py
