Infinite DataLoader

Is there a good way to have an infinite dataloader? That is, is there a class that loops automatically and provides something like data_loader.get_next()? And how can I still keep track of full iterations over the dataset?

6 Likes

Well, one quick and dirty hack would be for your CustomDataset to return a very high number (e.g. np.iinfo(np.int64).max) from its __len__.

As for get_next(), you can get the iterator from the dataloader and call next on that:

next(dataloader.__iter__())

1 Like

Thanks! However, will the data sampling strategy be affected this way? That is, if I have 10 batches of data and call next 20 times, will that amount to 2 full iterations?

1 Like

Any thoughts?:sweat_smile:

Alternatively you should be able to just catch the StopIteration exception and reinitialize the generator (maybe with some previous shuffling)

2 Likes

Could you provide a basic example on how to do so? Thanks!!

Also it seems that every time I called next(dataloader.__iter__()), I will get the same data.

1 Like
next(data_loader.__iter__()) creates a new iterator starting at the beginning of the dataset every time you call it.
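A quick sketch showing the difference (using a plain list as a stand-in dataset, chosen just for illustration):

```python
import torch
from torch.utils.data import DataLoader

data_loader = DataLoader(list(range(8)), batch_size=2)

# a fresh iterator on every call -> always returns the first batch
print(next(iter(data_loader)))  # tensor([0, 1])
print(next(iter(data_loader)))  # tensor([0, 1])

# a single stored iterator advances through the dataset as expected
data_iter = iter(data_loader)
print(next(data_iter))  # tensor([0, 1])
print(next(data_iter))  # tensor([2, 3])
```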

# create dataloader-iterator
data_iter = iter(data_loader) 

# iterate over dataset
# alternatively you could use while True
for i in range(NUM_ITERS_YOU_WANT):
    try:
        data = next(data_iter) 
    except StopIteration:
        # StopIteration is thrown if dataset ends
        # reinitialize data loader 
        data_iter = iter(data_loader)
        data = next(data_iter) 

Could do the trick (I did not test the code, just typed it here)

9 Likes

Oh yeah that’s the one. I also just figured it out! Thanks!

Yes, that will be 2 full iterations. And you can even disregard the shuffling of the data indices by the dataloader and follow your own strategy in the next function. For example, irrespective of what index the dataloader sends, you can calculate your own index based on your dataset; that way your strategy would not be affected even if you put a huge number in __len__. It’s all subjective how you want to design your dataset.

This should do the trick.

def loopy(dl):
    while True:
        for x in dl:
            yield x

Then, everywhere you used to use iter(dataloader), use loopy(dataloader).

Beware though, that samplers passed to the DataLoader can raise StopIteration() exceptions.
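For instance, with a toy list dataset (a minimal sketch, not a full training loop), the generator simply wraps around when the dataset is exhausted:

```python
import torch
from torch.utils.data import DataLoader

def loopy(dl):
    while True:
        for x in dl:
            yield x

data_loader = DataLoader(list(range(4)), batch_size=2)
batches = loopy(data_loader)

# more next() calls than the dataset has batches: it wraps around
for _ in range(4):
    print(next(batches))
# tensor([0, 1]), tensor([2, 3]), tensor([0, 1]), tensor([2, 3])
```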

3 Likes

@justusschock how can I make the data loader shuffle the data every time its iterator is reinitialized?

@AlexisW maybe you know?

Official non-hacky support for this is happening in https://github.com/pytorch/pytorch/pull/19228

4 Likes

As @SimonW mentioned, the release of PyTorch 1.2 brought with it a new dataset class: torch.utils.data.IterableDataset.

Here you can read the official documentation for it: IterableDataset.
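A minimal sketch of an infinite IterableDataset (class name and values are made up for illustration):

```python
import itertools
import torch
from torch.utils.data import IterableDataset, DataLoader

class InfiniteListDataset(IterableDataset):
    """Yields its values forever, so the DataLoader never raises StopIteration."""
    def __init__(self, values):
        self.values = values

    def __iter__(self):
        # itertools.cycle repeats the sequence endlessly
        return itertools.cycle(self.values)

loader = DataLoader(InfiniteListDataset([1, 2, 3, 4]), batch_size=3)
it = iter(loader)
print(next(it))  # tensor([1, 2, 3])
print(next(it))  # tensor([4, 1, 2]) -- wraps around the dataset boundary
```

Note that with num_workers > 0, each worker would run its own copy of the cycle and duplicate the data, so you'd need per-worker sharding in __iter__.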

from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    def __init__(self):
        self.varList = [1,2,3,4]
    def __len__(self):
        return len(self.varList)
    def __getitem__(self, idx) :
        return self.varList[idx]

class customSampler(Sampler) :
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle
    
    def __iter__(self):
        order = list(range(len(self.dataset)))
        idx = 0
        while True:
            yield order[idx]
            idx += 1
            if idx == len(order):
                # one full pass done: optionally reshuffle, then wrap around
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

if __name__ == "__main__":
    dset = listDataset()
    sampler = customSampler(dset, shuffle=True)
    loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
    for x in range(10):
        i = next(loader)
        print(i)
    

A dataloader that never raises StopIteration can be built like this:

import torch
from torch.utils.data import DataLoader

class InfiniteDataLoader:
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data

if __name__ == "__main__":
    # Replace this with your actual data loader creation
    data_loader = DataLoader(dataset=your_dataset, batch_size=your_batch_size, shuffle=True)

    infinite_loader = InfiniteDataLoader(data_loader)

    # You can now use the infinite_loader in a for loop or any other iteration
    for data in infinite_loader:
        pass  # your data processing code here (remember to break eventually)

If you are using DistributedSampler, you may need the following, which advances the sampler's epoch on each restart so the shuffle order changes:

class InfiniteDataLoader:
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)
        if not hasattr(self.data_loader.sampler, 'epoch'): 
            self.data_loader.sampler.epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            # bumping the epoch is what makes DistributedSampler reshuffle
            # (equivalent to calling sampler.set_epoch(epoch + 1))
            self.data_loader.sampler.epoch += 1
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data
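For reference, set_epoch is the official API behind this; the sketch below (num_replicas and rank are picked arbitrarily so the demo runs in a single process, without a process group) shows the shuffle order changing between epochs:

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

dataset = list(range(8))
# explicit num_replicas/rank so no distributed init is needed for the demo
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # changes the shuffling seed for this pass
    print([int(x) for batch in loader for x in batch])
```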