Not understanding how to make an iterable dataset

I’m having a hard time understanding the example iterable dataset in the docs, and how to apply it to my own code. The example code is:

>>> import math
>>> import torch
>>> class MyIterableDataset(torch.utils.data.IterableDataset):
...     def __init__(self, start, end):
...         super().__init__()
...         assert end > start, "this example code only works with end > start"
...         self.start = start
...         self.end = end
...
...     def __iter__(self):
...         worker_info = torch.utils.data.get_worker_info()
...         if worker_info is None:  # single-process data loading, return the full iterator
...             iter_start = self.start
...             iter_end = self.end
...         else:  # in a worker process
...             # split workload
...             per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
...             worker_id = worker_info.id
...             iter_start = self.start + worker_id * per_worker
...             iter_end = min(iter_start + per_worker, self.end)
...         return iter(range(iter_start, iter_end))
...
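In the docs example, each worker computes its own contiguous slice of [start, end) and returns an iterator over just that slice. iter(range(...)) is simply the shortest way to produce an iterator; a generator would work equally well. The split arithmetic can be pulled out into a plain function to see which values each worker gets (worker_range is a hypothetical name, not part of torch, and the sketch runs without PyTorch):

```python
import math


def worker_range(start: int, end: int, num_workers: int, worker_id: int) -> range:
    """Slice of [start, end) that the docs example assigns to one worker."""
    # Same arithmetic as MyIterableDataset.__iter__: ceil-divide the range evenly.
    per_worker = int(math.ceil((end - start) / float(num_workers)))
    iter_start = start + worker_id * per_worker
    iter_end = min(iter_start + per_worker, end)
    # If iter_start >= iter_end, range() is simply empty for that worker.
    return range(iter_start, iter_end)


for wid in range(3):
    print(wid, list(worker_range(0, 10, num_workers=3, worker_id=wid)))
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9]
```

Together the three slices cover 0 to 9 exactly once, which is why the example avoids duplicated data when num_workers > 1.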

The IterableDataset I have created loops through a set of files (the data is already randomly shuffled, so that is not an issue) and yields one row at a time:

    def __iter__(self) -> Iterator[dict]:
        for file in self.files:
            if self.compressed:
                fopen = BZ2File(filename=file, mode='r')
            else:
                fopen = open(file=file, mode='r')

            with fopen as f:
                for row in f:
                    data = json.loads(row)
                    tokens = data['words']
                    tokens = tokens[:self.max_length]
                    random.shuffle(tokens)
                    indices, mask = self.tokens_to_indices(tokens)

                    item = {
                        'src': indices,
                        'mask': mask,
                        'label': data['label']
                    }

                    yield item
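One detail in the loop above: BZ2File opens files in binary mode, so the compressed branch yields bytes rows while open(file=file, mode='r') yields str rows. json.loads accepts both (Python 3.6+), so the code works either way, but if you want both branches to behave identically, bz2.open with mode 'rt' returns a text stream. A minimal sketch, assuming UTF-8 JSON-lines files (open_rows is a hypothetical helper name):

```python
import bz2
import json
import os
import tempfile
from typing import Iterator


def open_rows(path: str, compressed: bool) -> Iterator[dict]:
    """Yield one parsed JSON object per line from a plain or bz2-compressed file."""
    # bz2.open(..., "rt") and open(..., "rt") both return text streams,
    # so each line is a str in either branch.
    opener = bz2.open if compressed else open
    with opener(path, "rt", encoding="utf-8") as f:
        for line in f:
            yield json.loads(line)


# Usage: write the same row to a plain file and a .bz2 file, then read both back.
with tempfile.TemporaryDirectory() as tmp:
    plain_path = os.path.join(tmp, "part0.jsonl")
    packed_path = os.path.join(tmp, "part1.jsonl.bz2")
    line = '{"words": ["a", "b"], "label": 1}\n'
    with open(plain_path, "w", encoding="utf-8") as f:
        f.write(line)
    with bz2.open(packed_path, "wt", encoding="utf-8") as f:
        f.write(line)
    rows_plain = list(open_rows(plain_path, compressed=False))
    rows_packed = list(open_rows(packed_path, compressed=True))

print(rows_plain == rows_packed)  # True
```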

I’m not sure how returning an iterator, as the docs example does, applies to this scenario. Any advice would be appreciated!

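It’s fine to use yield inside the __iter__ function: a function containing yield is a generator function, so calling __iter__ returns a generator, which is already an iterator. To instantiate the dataset and use it, you can do something similar to the following: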

    dataset = MyIterableDataset(...)
    dataset = iter(dataset)  # iter() calls __iter__ and returns an iterator you can pass to next()

    for n_iteration in range(100):
        batch_item = next(dataset)
        # batch_item is a single dict with "src", "mask", and "label" (one row, not a batch)
        ...
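To see the iter/next pattern with a yield-based __iter__ end to end, here is a plain-Python sketch (RowsDataset is a hypothetical class that does not subclass torch.utils.data.IterableDataset, so it runs without PyTorch). It also shows what next() does once a finite loop runs out:

```python
from typing import Iterator, List


class RowsDataset:
    """Hypothetical stand-in for an IterableDataset whose __iter__ uses yield."""

    def __init__(self, rows: List[dict]):
        self.rows = rows

    def __iter__(self) -> Iterator[dict]:
        # Because this body contains `yield`, calling __iter__ returns a
        # generator object, which already implements __next__.
        for row in self.rows:
            yield row


dataset = RowsDataset([{"label": 0}, {"label": 1}])
it = iter(dataset)             # a fresh generator each time iter() is called
print(next(it))                # {'label': 0}
print(next(it))                # {'label': 1}
print(next(it, "exhausted"))   # exhausted: the finite loop has ended
print(list(dataset))           # [{'label': 0}, {'label': 1}] (new pass, new generator)
```

With a bare next(), the third call would raise StopIteration instead of returning the default; a for loop over the dataset (or over a DataLoader wrapping it) simply ends when that happens.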

Note that your dataset will eventually “end”, since you’re doing a finite for-loop over self.files. If you want your IterableDataset to continue indefinitely, wrap the body of __iter__ in a while True loop, e.g.:

    def __iter__(self) -> Iterator[dict]:
        while True:
            for file in self.files:
                if self.compressed:
            ...
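A minimal pure-Python sketch of that while True structure, assuming each file is represented by an in-memory list of lines (endless_rows is a hypothetical name); itertools.islice is one way for the consumer to take a bounded number of items:

```python
from itertools import islice
from typing import Iterator, List


def endless_rows(files: List[List[str]]) -> Iterator[str]:
    """Cycle through every 'file' forever, as in the while True version of __iter__."""
    # Each inner list is a hypothetical stand-in for the lines of one file.
    while True:
        for file_rows in files:
            for row in file_rows:
                yield row


stream = endless_rows([["a", "b"], ["c"]])
print(list(islice(stream, 7)))  # ['a', 'b', 'c', 'a', 'b', 'c', 'a']
```

Because the stream never ends on its own, the training loop has to decide when to stop (for example, after a fixed number of steps). Also note that if files is empty, the while loop spins forever without yielding anything.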

Hey @actuallyaswin, thanks for your response! About your last point, that the dataset will eventually “end”: I found that actually isn’t the case. It keeps iterating through batches even after it should have exhausted all the data. Perhaps this is a behavior of pytorch-lightning, which I’m using to train the model.
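One plausible explanation, not confirmed in this thread: the PyTorch docs note that with multi-process loading the IterableDataset object is replicated in every worker, so unless each replica reads a different part of the data, every worker yields every row and the data appears num_workers times per epoch. The sketch below simulates workers with a plain loop rather than real DataLoader processes (an assumption, so it runs without PyTorch) and shows how striding with itertools.islice gives each worker a disjoint share:

```python
from itertools import islice
from typing import Iterable, Iterator, List

# Hypothetical stand-in for the rows of one file.
ROWS: List[str] = [f"row{i}" for i in range(7)]


def worker_rows(rows: Iterable[str], worker_id: int, num_workers: int, shard: bool) -> Iterator[str]:
    """Rows one simulated worker would yield from its own copy of the dataset."""
    if shard:
        # Worker k keeps rows k, k + num_workers, k + 2 * num_workers, ...
        return islice(rows, worker_id, None, num_workers)
    # Without sharding, every worker iterates over the full dataset.
    return iter(rows)


def simulate_loader(num_workers: int, shard: bool) -> List[str]:
    """Collect everything the simulated workers produce in one epoch.

    Real DataLoader workers interleave their output; here it is just concatenated.
    """
    out: List[str] = []
    for worker_id in range(num_workers):
        out.extend(worker_rows(ROWS, worker_id, num_workers, shard))
    return out


unsharded = simulate_loader(num_workers=3, shard=False)
sharded = simulate_loader(num_workers=3, shard=True)
print(len(unsharded))  # 21: every row appears three times
print(len(sharded))    # 7: each row appears exactly once
```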

On a related note, it looks like I can take advantage of multiple workers simply by using islice when iterating over the files:

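        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process loading (num_workers=0): read every row
            worker_total_num, worker_id = 1, 0
        else:  # each worker keeps every worker_total_num-th row, starting at its own id
            worker_total_num = worker_info.num_workers
            worker_id = worker_info.id
        for file in self.files:
            if self.compressed:
                fopen = BZ2File(filename=file, mode='r')
            else:
                fopen = open(file=file, mode='r')

            with fopen as f:
                for row in islice(f, worker_id, None, worker_total_num):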