How to use DataLoader with IterableDataset

For an infinite streaming dataset, how about something like this:

class DataStream1(IterableDataset):
    """Endless stream of random ``(x, y)`` tensor pairs.

    Every sample is a pair of 1-D tensors produced by ``torch.rand``: ``x``
    holds ``size_input`` values and ``y`` holds ``size_output`` values. The
    stream never terminates, so the consumer decides when to stop iterating.

    Args:
        size_input: Number of elements in each ``x`` tensor (default 4).
        size_output: Number of elements in each ``y`` tensor (default 2).
    """

    def __init__(self, size_input: int = 4, size_output: int = 2) -> None:
        super().__init__()
        self.size_input = size_input
        self.size_output = size_output

    def generate(self):
        """Yield ``(x, y)`` pairs of freshly drawn random tensors forever."""
        while True:
            x = torch.rand(self.size_input)
            y = torch.rand(self.size_output)
            yield x, y

    def __iter__(self):
        """Return a new generator over the sample stream."""
        # A generator object is already an iterator; wrapping it in iter()
        # would just return the same object.
        return self.generate()

# Build the random sample stream and wrap it in a loader with default options.
dataset = DataStream1()
train_loader = DataLoader(dataset)

# The underlying stream is endless, so this loop prints until it is interrupted.
for step, batch in enumerate(train_loader):
    print(step, batch)
1 Like