Why does my custom dataset class loop forever?

I wrote my own custom dataset class, but when I try to iterate through its data one item at a time I get an infinite loop.

I went to the extreme of making the __len__ method always return 0, and even that didn't stop it from looping through my dataset endlessly. How do I stop it? Consider this loop:

for i, data in enumerate(dataset):
    print(i)
    print(data)

Why does this keep calling the next method, and how does it ever stop?


Hoping to get a reproducible example of the error, I wrote this:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.bob = [0,1,2]

    def __len__(self):
        return 0

    def __getitem__(self, idx):
        print(f'idx = {idx}')
        return self.bob[idx]

Now it does stop, even though __len__ returns 0…

Hi,
I tried this:

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.bob = [0,1,2]
    def __len__(self):
        return len(self.bob)
    def __getitem__(self, idx):
        print(f'idx = {idx}')
        return self.bob[idx]

dataset = DataLoader(MyDataset(), batch_size = 1, shuffle = True, num_workers = 0)

for i, data in enumerate(dataset):
    print(i)
    print(data)

and got the following output, which did not loop forever (shuffle = True, so the index order will differ between runs):

idx = 1
0
tensor([1])
idx = 0
1
tensor([0])
idx = 2
2
tensor([2])

Is your DataLoader() set up correctly? Your custom Dataset looks fine.

I think I figured it out: the stopping condition for looping through the dataset has to come from the len function. For some reason, though, the strange example I cooked up DOES stop without looping forever, so it might be a weird edge case. I expect everything will work fine once I wrap the dataset in the DataLoader class (I hope).
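If you do want a plain for loop to stop at __len__, one option (a sketch, not something proposed in this thread) is to define __iter__ yourself: Python uses __iter__ when a class has one and only falls back to calling __getitem__ with 0, 1, 2, … when it doesn't. The class names LenBoundedDataset and ZeroLenDataset are hypothetical, and plain classes stand in for Dataset subclasses so the snippet runs without torch.

```python
class LenBoundedDataset:
    """Hypothetical example: iteration is driven by __len__ through an explicit __iter__."""

    def __init__(self):
        self.bob = [0, 1, 2]

    def __len__(self):
        return len(self.bob)

    def __getitem__(self, idx):
        return self.bob[idx]

    def __iter__(self):
        # A for loop calls this instead of probing __getitem__(0), (1), ... until IndexError
        for idx in range(len(self)):
            yield self[idx]


class ZeroLenDataset(LenBoundedDataset):
    """Hypothetical: same data, but __len__ reports 0."""

    def __len__(self):
        return 0  # with the __iter__ above, a for loop now yields nothing


print(list(LenBoundedDataset()))  # [0, 1, 2]
print(list(ZeroLenDataset()))     # []
```

The trade-off is that __len__ and __getitem__ must now agree; if __len__ overstates the size, __getitem__ will raise an IndexError out of your loop instead of ending it quietly.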

Make sure your __getitem__ raises an IndexError for out-of-range indexes. If __getitem__ never raises an exception, a plain for loop over the dataset will run forever, because that loop doesn't use __len__ at all.
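To illustrate the difference, here is a minimal sketch with two hypothetical plain classes (NeverEnding and Bounded, not from the thread). NeverEnding wraps the index so __getitem__ never raises, which is one way a real dataset can end up looping forever; Bounded checks the index and raises IndexError. itertools.islice is used only to cap the infinite case so the example terminates.

```python
import itertools


class NeverEnding:
    """Hypothetical dataset whose __getitem__ never raises (it wraps the index)."""

    def __init__(self):
        self.items = [0, 1, 2]

    def __len__(self):
        return 0  # ignored by a plain for loop

    def __getitem__(self, idx):
        return self.items[idx % len(self.items)]  # never raises IndexError


class Bounded:
    """Same data, but out-of-range indexes raise IndexError, which ends iteration."""

    def __init__(self):
        self.items = [0, 1, 2]

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        if idx < 0 or idx >= len(self.items):
            raise IndexError(f"index {idx} out of range")
        return self.items[idx]


# islice stops after 7 items; a bare `for x in NeverEnding()` would never stop
first_seven = list(itertools.islice(NeverEnding(), 7))
print(first_seven)      # [0, 1, 2, 0, 1, 2, 0]

print(list(Bounded()))  # [0, 1, 2]
```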

The example you posted does raise an IndexError at the right moment: calling __getitem__ with idx = 3 evaluates self.bob[3], which raises the error.

https://docs.python.org/3/reference/datamodel.html


I ran it in a Python unit test with no errors. Are you sure it raises an error for you with that simple example?


Yes. Notice that it prints idx = 3: evaluating self.bob[3] raises an IndexError, and the for loop catches that error and treats it as the signal to stop. That is also why your unit test reports no error: the IndexError never propagates out of the loop.
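Roughly speaking, a for loop over such a dataset behaves like the manual iter()/next() loop in this sketch. ZeroLenSeq is a hypothetical plain-class stand-in for MyDataset so it runs without torch. Because the class defines no __iter__, iter() falls back to the sequence protocol and calls __getitem__ with 0, 1, 2, … until it raises IndexError, which the iterator reports as StopIteration.

```python
class ZeroLenSeq:
    """Hypothetical stand-in for MyDataset: __len__ says 0, but __getitem__ works."""

    def __init__(self):
        self.bob = [0, 1, 2]

    def __len__(self):
        return 0  # never consulted by iter()/next() below

    def __getitem__(self, idx):
        print(f"idx = {idx}")
        return self.bob[idx]  # raises IndexError when idx == 3


# Approximately what `for data in ZeroLenSeq(): ...` does under the hood
it = iter(ZeroLenSeq())  # no __iter__, so Python wraps __getitem__ in a sequence iterator
seen = []
while True:
    try:
        data = next(it)  # calls __getitem__(0), then 1, 2, 3
    except StopIteration:  # the IndexError from __getitem__(3) surfaces here as StopIteration
        break
    seen.append(data)

print(seen)  # [0, 1, 2]
```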

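This doesn't happen if you wrap it in a DataLoader, as in @Prerna_Dhareshwar's example, because DataLoader uses __len__: its default samplers only generate indexes from 0 to len(dataset) - 1, so __getitem__ is never called with an out-of-range index.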
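As a rough illustration of that point, here is a heavily simplified, hypothetical stand-in for a sequential (non-shuffled) DataLoader. It is not PyTorch's implementation: it omits shuffling, collation into tensors, and worker processes. It only shows that when indexes come from range(len(dataset)), __len__ alone decides how much data is loaded. simple_loader and LenDataset are illustrative names.

```python
def simple_loader(dataset, batch_size=1):
    """Hypothetical sketch: yield lists of items, drawing indexes only from range(len(dataset))."""
    indices = range(len(dataset))
    for start in range(0, len(indices), batch_size):
        yield [dataset[i] for i in indices[start:start + batch_size]]


class LenDataset:
    """Hypothetical dataset whose reported length is configurable."""

    def __init__(self, n):
        self.bob = [0, 1, 2]
        self.n = n  # the value __len__ reports

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return self.bob[idx]


print(list(simple_loader(LenDataset(3))))                # [[0], [1], [2]]
print(list(simple_loader(LenDataset(3), batch_size=2)))  # [[0, 1], [2]]
print(list(simple_loader(LenDataset(0))))                # [] : __len__ of 0 means no data
```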

Here is the full runnable example:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        self.bob = [0,1,2]

    def __len__(self):
        print('len')
        return 0

    def __getitem__(self, idx):
        print(f'idx = {idx}')
        return self.bob[idx]


dataset = MyDataset()
for i, data in enumerate(dataset):
    print(i)
    print(data)

Here is the output:

idx = 0
0
0
idx = 1
1
1
idx = 2
2
2
idx = 3