I wrote my own custom dataset class, but when I try to iterate through its data one item at a time I get an infinite loop.
I went to the extreme of having the __len__ method always return 0, and even that didn't stop it from continually looping through my dataset. How do I stop it? Why does:
```python
for i, data in enumerate(dataset):
    print(i)
    print(data)
```
keep calling the dataset's __getitem__? How does it ever stop?
In the hope of producing a reproducible error, I coded this:
I think I figured it out: it's that the stopping condition for looping through the dataset has to use the len function. Though for some reason it DOES work on the strange example I cooked up without looping forever… so it might be a weird edge case. I expect once I wrap it with the DataLoader class everything will work fine (I hope).
Make sure your __getitem__ raises an IndexError for out-of-range indices. If your __getitem__ never raises an exception, the loop will run forever. A plain for loop doesn't use __len__ at all; it keeps calling __getitem__ with 0, 1, 2, … until an IndexError escapes.
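What the for loop does here can be sketched in plain Python: when an object defines __getitem__ but no __iter__, the interpreter falls back to the legacy sequence protocol, indexing from 0 upward until an IndexError escapes. A minimal sketch (the class name `Legacy` is just for illustration):

```python
class Legacy:
    """Defines __getitem__ but no __iter__, so a for loop uses the
    legacy sequence protocol: obj[0], obj[1], ... until IndexError."""
    def __init__(self):
        self.bob = [0, 1, 2]

    def __len__(self):
        # Never consulted by a plain for loop.
        return 0

    def __getitem__(self, idx):
        return self.bob[idx]  # raises IndexError at idx == 3

seen = [x for x in Legacy()]
print(seen)  # the loop stops cleanly even though __len__ returns 0
```

Note that the loop terminates at the right place purely because `self.bob[3]` raises IndexError, not because of anything __len__ says.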
The example you posted raises an IndexError at the right time because __getitem__(3) evaluates self.bob[3], which raises the error.
Are you sure it raises an error in yours with that simple/trivial example?
Yes. Notice that it prints idx = 3. Evaluating self.bob[3] raises an IndexError, which the for loop machinery catches and treats as the end of iteration.
This doesn't happen if you wrap it in a DataLoader, as in @Prerna_Dhareshwar's example, because DataLoader uses __len__ to drive iteration.
Here is the full runnable example:
```python
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.bob = [0, 1, 2]

    def __len__(self):
        print('len')
        return 0  # deliberately wrong: the for loop below never calls this

    def __getitem__(self, idx):
        print(f'idx = {idx}')
        return self.bob[idx]  # raises IndexError when idx == 3

dataset = MyDataset()
for i, data in enumerate(dataset):
    print(i)
    print(data)
```
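You can also watch the IndexError-to-StopIteration conversion directly by driving the iterator by hand. This sketch uses a plain class of the same shape (the torch Dataset base class isn't needed for the iteration behavior itself):

```python
class MyDatasetSketch:
    """Same shape as MyDataset above, minus the torch dependency."""
    def __init__(self):
        self.bob = [0, 1, 2]

    def __len__(self):
        return 0

    def __getitem__(self, idx):
        return self.bob[idx]

it = iter(MyDatasetSketch())          # legacy sequence iterator
print(next(it), next(it), next(it))   # the three real items
try:
    next(it)                          # __getitem__(3) raises IndexError ...
except StopIteration:
    print('loop would stop here')     # ... which iter() turns into StopIteration
```

This is exactly the mechanism the for loop relies on: the IndexError from self.bob[3] surfaces as StopIteration, ending the loop.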