How to use DataLoader with IterableDataset

I would like to use IterableDataset to create an infinite dataset that I can pass to DataLoader.

I tried two approaches and would like to know which one should be preferred, or if there is a better solution for an infinite stream of data in PyTorch.

I also have the problem that changing the batch_size argument has no effect. What am I doing wrong here?

What about other parameters such as pin_memory, shuffle, etc.? Can I simply ignore them when working with an infinite dataset?

Here is a minimal working example:

import torch
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import IterableDataset


class DataStream1(IterableDataset):

    def __init__(self) -> None:
        super().__init__()
        self.size_input = 4
        self.size_output = 2

    def generate(self):
        x = torch.rand(self.size_input) 
        y = torch.rand(self.size_output) 
        yield x, y 

    def __iter__(self):
        return iter(self.generate())


class Stream:

    def __init__(self):
        self.size_input = 4
        self.size_output = 2

    def __iter__(self):
        # DataLoader only calls __next__ on this object, but defining
        # __iter__ as well makes Stream a proper Python iterator.
        return self

    def __next__(self):
        x = torch.rand(self.size_input)
        y = torch.rand(self.size_output)
        return x, y


class DataStream2(IterableDataset):

    def __init__(self) -> None:
        super().__init__()

    def __iter__(self):
        return Stream() 


def test(data_set):
    train_loader = DataLoader(
        dataset=data_set,
        pin_memory=False,
        shuffle=False,
        batch_size=2,
        num_workers=1,
    )
    for epoch in range(2):
        print(f"\n{epoch = }")
        for x, y in train_loader:
            print(f"{x = }")
            print(f"{y = }")


def main():

    # Ignores batch size; each epoch ends after a single sample.
    data_stream_1 = DataStream1()
    test(data_stream_1)

    # Ignores batch size, and never exits the train_loader loop.
    data_stream_2 = DataStream2()
    test(data_stream_2)


if __name__ == "__main__":
    main()

For an infinite streaming dataset, how about something like this:

class DataStream1(IterableDataset):

    def __init__(self) -> None:
        super().__init__()
        self.size_input = 4
        self.size_output = 2

    def generate(self):
        while True:
            x = torch.rand(self.size_input)
            y = torch.rand(self.size_output)
            yield x, y

    def __iter__(self):
        return iter(self.generate())

dataset = DataStream1()

train_loader = DataLoader(dataset=dataset)

for i, data in enumerate(train_loader):
    print(i, data)
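
Note that since the stream never ends, this loop will run forever. If you want epoch-like behavior with an infinite dataset, you can cap the number of batches per epoch yourself, for example with itertools.islice (a minimal sketch; steps_per_epoch is an arbitrary value chosen for illustration):

import itertools

steps_per_epoch = 5  # arbitrary cap, just for illustration

for epoch in range(2):
    # islice stops the otherwise infinite loader after steps_per_epoch batches
    for i, data in enumerate(itertools.islice(train_loader, steps_per_epoch)):
        print(epoch, i, data)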

Thank you for your response.

I don’t see how your answer differs from my DataStream1 class. I would also like to know how I can define a batch size. Maybe you can say some more about that using your approach. Thank you.


The difference is in the generate function: I added a while True loop, which yields an infinite stream of data, whereas your DataStream1.generate yields a single sample and is then exhausted. To get a batch size, it can be updated as:

    def generate(self):
        # Assumes self.batch_size is set in __init__.
        while True:
            x = torch.rand(self.batch_size, self.size_input)
            y = torch.rand(self.batch_size, self.size_output)
            yield x, y
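
If you batch inside generate like this, you should also pass batch_size=None to the DataLoader, which disables its automatic batching; otherwise the loader would stack your pre-batched samples and add an extra batch dimension. Here is a minimal sketch of the complete class, assuming the batch size is passed to __init__ (the name BatchedDataStream is just for illustration):

import torch
from torch.utils.data import DataLoader, IterableDataset


class BatchedDataStream(IterableDataset):

    def __init__(self, batch_size: int) -> None:
        super().__init__()
        self.size_input = 4
        self.size_output = 2
        self.batch_size = batch_size

    def generate(self):
        while True:
            # Each yielded item is already a complete batch.
            x = torch.rand(self.batch_size, self.size_input)
            y = torch.rand(self.batch_size, self.size_output)
            yield x, y

    def __iter__(self):
        return iter(self.generate())


# batch_size=None turns off automatic batching, so the batches
# built in generate() pass through the DataLoader unchanged.
train_loader = DataLoader(BatchedDataStream(batch_size=8), batch_size=None)

x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([8, 4])
print(y.shape)  # torch.Size([8, 2])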