Is there a good way to have an infinite dataloader? That is, is there a class that loops automatically and provides something like data_loader.get_next()? And how would full iterations (epochs) be tracked?
Well, one quick and dirty hack would be for your CustomDataset to return a very high number (e.g. np.iinfo(np.int64).max) in its __len__.
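A minimal sketch of that hack (the modulo indexing in __getitem__ is my assumption for how to map the huge index range back onto the real data; the thread itself doesn't specify it):

```python
import numpy as np
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        # Report an effectively infinite length
        return np.iinfo(np.int64).max

    def __getitem__(self, idx):
        # Wrap any index back onto the real data
        return self.data[idx % len(self.data)]
```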
As for get_next(), you can get the iterator from the dataloader and call next on that:
next(dataloader.__iter__())
Thanks! However, will the data sampling strategy be affected this way? That is, if I have 10 batches of data and I call next 20 times, will that amount to 2 full iterations?
Any thoughts?
Alternatively, you should be able to just catch the StopIteration exception and reinitialize the generator (maybe with some additional shuffling).
Could you provide a basic example on how to do so? Thanks!!
Also, it seems that every time I call next(dataloader.__iter__()), I get the same data.
next(dataloader.__iter__())
creates a new iterator starting at the beginning of the dataset every time you call it.
# create dataloader-iterator
data_iter = iter(data_loader)

# iterate over dataset
# alternatively you could use while(True)
for i in range(NUM_ITERS_YOU_WANT):
    try:
        data = next(data_iter)
    except StopIteration:
        # StopIteration is thrown if dataset ends
        # reinitialize data loader
        data_iter = iter(data_loader)
        data = next(data_iter)
That could do the trick (I did not test the code, just typed it here).
Oh yeah that’s the one. I also just figured it out! Thanks!
Yes, that will be 2 full iterations. You can even disregard the shuffling of the data indices by the dataloader and follow your own strategy in your dataset's __getitem__. For example, irrespective of what index the dataloader sends, you can calculate your own index based on your dataset; that way your strategy is not affected even if you put a huge number in __len__. It's all subjective how you want to design your dataset.
This should do the trick.
def loopy(dl):
    while True:
        for x in iter(dl):
            yield x
Then, everywhere you used to use iter(dataloader), use loopy(dataloader).
Beware, though, that samplers passed to the DataLoader can raise StopIteration() exceptions.
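For instance (repeating the loopy helper so the snippet is self-contained; the toy TensorDataset is just for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def loopy(dl):
    while True:
        for x in iter(dl):
            yield x

dataset = TensorDataset(torch.arange(10))   # 10 samples
loader = DataLoader(dataset, batch_size=4)  # 3 batches per epoch
batches = loopy(loader)

# Pull 6 batches (two full passes) without ever hitting StopIteration
for _ in range(6):
    (x,) = next(batches)
```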
@justusschock how can I make the data loader shuffle the data every time its iterator is reinitialized?
@AlexisW maybe you know?
Official non-hacky support for this is happening in https://github.com/pytorch/pytorch/pull/19228
As @SimonW mentioned, the release of PyTorch 1.2 brought with it a new dataset class: torch.utils.data.IterableDataset.
You can read the official documentation for it here: IterableDataset.
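As a sketch of how IterableDataset can give an endless stream (the InfiniteDataset class and its reshuffling loop are my own illustration, not from the docs):

```python
import itertools
import torch
from torch.utils.data import IterableDataset, DataLoader

class InfiniteDataset(IterableDataset):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        # Yield samples forever, reshuffling on every pass
        while True:
            order = torch.randperm(len(self.data))
            for i in order.tolist():
                yield self.data[i]

loader = DataLoader(InfiniteDataset(list(range(8))), batch_size=4)

# Take as many batches as you like; the stream never raises StopIteration
for batch in itertools.islice(loader, 5):
    print(batch)
```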
from torch.utils.data import DataLoader, Dataset, Sampler
import random

class listDataset(Dataset):
    def __init__(self):
        self.varList = [1, 2, 3, 4]

    def __len__(self):
        return len(self.varList)

    def __getitem__(self, idx):
        return self.varList[idx]

class customSampler(Sampler):
    def __init__(self, dataset, shuffle):
        assert len(dataset) > 0
        self.dataset = dataset
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.dataset)))
        idx = 0
        while True:
            yield order[idx]
            idx += 1
            if idx == len(order):
                if self.shuffle:
                    random.shuffle(order)
                idx = 0

if __name__ == "__main__":
    dset = listDataset()
    sampler = customSampler(dset, shuffle=True)
    loader = iter(DataLoader(dataset=dset, sampler=sampler, batch_size=6, num_workers=2))
    for x in range(10):
        i = next(loader)
        print(i)
A dataloader that never raises StopIteration can be made like this:
import torch
from torch.utils.data import DataLoader

class InfiniteDataLoader:
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data

if __name__ == "__main__":
    # Replace this with your actual data loader creation
    data_loader = DataLoader(dataset=your_dataset, batch_size=your_batch_size, shuffle=True)
    infinite_loader = InfiniteDataLoader(data_loader)

    # You can now use the infinite_loader in a for loop or any other iteration
    for data in infinite_loader:
        # Your data processing code here
        ...
If you are using DistributedSampler, you may need the following:
class InfiniteDataLoader:
    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)
        if not hasattr(self.data_loader.sampler, 'epoch'):
            self.data_loader.sampler.epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            data = next(self.data_iter)
        except StopIteration:
            self.data_loader.sampler.epoch += 1  # Advance the epoch so the sampler reshuffles
            self.data_iter = iter(self.data_loader)  # Reset the data loader
            data = next(self.data_iter)
        return data