Why is Dataset not converting to a list as expected?

I’m trying to do list(dataset), but it returns more elements than expected, and in some setups it even goes out of bounds. Sample code:

from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    def __init__(self):
        self.samples = list(range(1, 1001))

    def __len__(self):
        return len(self.samples)-500

    def __getitem__(self, idx):
        return self.samples[idx]


dataset = NumbersDataset()

len(dataset)
> 500

len(list(dataset))
> 1000

I would expect list(dataset) to contain 500 elements, but somehow __len__ is ignored.
The same thing happens when you iterate over the dataset with a for loop.
I mean, how does it even know that there are 1000 elements? :slight_smile:

Any ideas?

You see this because len() and __len__() are not the same thing, as explained here:

https://kushaldas.in/posts/len-function-in-python.html

If you do:

print(dataset.__len__())

you will get 500 as you expect.

There’s no problem with len(). See:

print(dataset.__len__(), len(dataset))
> 500 500

# But...
len(list(dataset))
> 1000

Both calls return 500. But len(list(dataset)) gives 1000…
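A quick torch-free check (the class name here is just illustrative) that len() really does dispatch to __len__, so the discrepancy must come from list(), not from len():

```python
class Sized:
    def __len__(self):
        return 7

s = Sized()
# len() simply calls the object's __len__ method
assert len(s) == 7
assert s.__len__() == len(s)
```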

I think the following example might help in understanding

>>> class X:
...     def __init__(self):
...             self.arr = [1, 2, 3]
...     def __getitem__(self, idx):
...             print('getting: ', idx)
...             return self.arr[idx]
...
>>> x = X()
>>> list(x)
getting:  0
getting:  1
getting:  2
getting:  3
[1, 2, 3]

I guess list() results in calling __getitem__ for each element.
The Dataset class is pretty much an abstract class in which one has to implement __getitem__. The __len__ function allows one to give the dataset a custom size, separate from the size of the actual iterable object.

Indeed it calls __getitem__. But list() uses __getitem__ with more indexes than expected.
In my first example __len__ returns 500. I’d expect list() to call __getitem__ 500 times, but it calls it 1000 times.

I think it calls __getitem__ for as many elements as there are in the iterable, regardless of what __len__ returns.

>>> class X:
...     def __init__(self):
...             self.arr = [1, 2, 3, 4, 5]
...     def __getitem__(self, idx):
...             print('getting: ', idx)
...             return self.arr[idx]
...     def __len__(self):
...             return 1
...
>>> x = X()
>>> list(x)
getting:  0
getting:  1
getting:  2
getting:  3
getting:  4
getting:  5
[1, 2, 3, 4, 5]

Correct. And that is exactly why I am confused. How does it even know that it should call it 5 times (or, in fact, even 6)?

__getitem__ doesn’t tell how many items there are, and list() doesn’t know that self.arr is the iterable it needs to iterate over. I thought it should check len(dataset) first and then iterate over that many items. That sounds logical to me.

And I fail to understand how it iterates over more items…

Btw @user_123454321,
why does it print getting: 5 instead of raising an IndexError, given that there is no item with index 5?

Because the samples list has 1000 elements. When you call list(dataset), it will contain 1000 elements.

Ah, interesting. I think it goes until it encounters IndexError.

>>> class X:
...     def __init__(self):
...             self.arr = [1, 2, 3, 4, 5]
...             self.arr2 = [1, 2, 3]
...     def __getitem__(self, idx):
...             print('getting: ', idx)
...             return self.arr[idx], self.arr2[idx]
...     def __len__(self):
...             return 1
...
>>> x = X()
>>> list(x)
getting:  0
getting:  1
getting:  2
getting:  3
[(1, 1), (2, 2), (3, 3)]
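What is happening under the hood is Python’s legacy sequence-iteration protocol: when a class defines __getitem__ but no __iter__, iter() builds a sequence iterator that calls __getitem__ with 0, 1, 2, … until an IndexError escapes. __len__ is never consulted. A torch-free sketch (class name is illustrative):

```python
class Seq:
    def __init__(self):
        self.arr = [10, 20, 30]

    def __len__(self):
        return 1  # deliberately "wrong": iteration ignores it

    def __getitem__(self, idx):
        return self.arr[idx]  # raises IndexError at idx == 3

s = Seq()
it = iter(s)      # falls back to the sequence protocol
print(type(it))   # in CPython: <class 'iterator'>, the built-in sequence iterator
print(list(it))   # [10, 20, 30] -- stopped by IndexError, not by __len__
```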

@1chimaruGin let’s have another example:

from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    def __init__(self):
        self.samples1 = list(range(1, 1001))
        self.samples2 = list(range(1, 11))

    def __len__(self):
        return len(self.samples1)-500

    def __getitem__(self, idx):
        return self.samples1[idx]
      
    def __iter__(self):
        return iter(self.samples1)


dataset = NumbersDataset()

len(dataset), len(list(dataset))
>>> (500, 1000)

How does list() know that it should iterate over .samples1 and not .samples2, which has only 10 items?

Probably because samples2 is not accessed in __getitem__?

@user_123454321 I thought so too. But it doesn’t stop at IndexError:

from torch.utils.data import Dataset

class NumbersDataset(Dataset):
    def __init__(self):
        self.samples = [1,2,3,4,5]

    def __len__(self):
        return 3

    def __getitem__(self, idx):
        print(f"Trying {idx}")
        if idx>3: raise IndexError
        return self.samples[idx]



dataset = NumbersDataset()

len(dataset), len(list(dataset))
>>> (3,5)

list(dataset)
>>> [1, 2, 3, 4, 5]

In Python, defining __getitem__ in a class makes it iterable.

See this Stack Overflow Q&A about __getitem__.

I am not sure why you added __iter__, but without it the iteration does stop at IndexError:

>>> class NumbersDataset:
...     def __init__(self):
...         self.samples = [1,2,3,4,5,6,7,9,10]
...     def __len__(self):
...         return 3
...     def __getitem__(self, idx):
...         print("Trying", idx)
...         if idx>3:
...                 print("Raising Index error")
...                 raise IndexError
...         return self.samples[idx]
...
>>> dataset = NumbersDataset()
>>>
>>> len(dataset), len(list(dataset))
Trying 0
Trying 1
Trying 2
Trying 3
Trying 4
Raising Index error
(3, 4)

I think __iter__ bypasses __getitem__.
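That matches what Python does: if __iter__ is defined, iter() (and therefore list() and for loops) uses it and never touches __getitem__. A quick torch-free check (names are illustrative):

```python
calls = []

class Both:
    def __iter__(self):
        return iter(['a', 'b'])

    def __getitem__(self, idx):
        calls.append(idx)  # record any __getitem__ access
        return 'x'

b = Both()
print(list(b))  # ['a', 'b'] -- taken from __iter__
print(calls)    # []        -- __getitem__ was never called
```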

Thanks to @user_123454321 we have an answer.
Indeed, __len__ doesn’t influence how an iterable is converted to a list or iterated over in a for loop. Instead, you have to raise an IndexError, which signals the end of the iteration. This is also described in the __getitem__ docs.

So, the correct implementation is as follows:

from torch.utils.data import Dataset

class NumbersDataset(Dataset):
  def __init__(self):
    self.samples = [1,2,3,4,5]

  def __len__(self):
    return 3

  def __getitem__(self, idx):
    if idx>2: 
      raise IndexError
    return self.samples[idx]


dataset = NumbersDataset()

len(dataset), len(list(dataset))
>>> (3, 3)

list(dataset)
>>> [1, 2, 3]
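One hedged variant, to keep __len__ and the IndexError in sync instead of hard-coding idx > 2: bounds-check against len(self). Plain Python shown here, but the same pattern works when subclassing torch.utils.data.Dataset, which adds nothing relevant to iteration:

```python
class NumbersDataset:
    def __init__(self):
        self.samples = [1, 2, 3, 4, 5]

    def __len__(self):
        return 3

    def __getitem__(self, idx):
        # bounds are derived from __len__, so the two can't drift apart
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        return self.samples[idx]

ds = NumbersDataset()
print(len(ds), list(ds))  # 3 [1, 2, 3]
```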

Actually, I have been using this discrepancy between the actual length of the iterable and what __len__ returns to randomly select an equal number (say 300) of augmented data points per epoch, regardless of the actual number of images, using something like:

import random
import torch

class Dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.samples = list(range(100))
    def __getitem__(self, index):
        act_index = random.randint(0, len(self.samples))
        return self.samples[act_index]
    def __len__(self):
        return 300

This way I get to always have 300 points from the dataloader.

This may work under certain circumstances. But I think for loops and list(dataset) will then be infinite: there is no IndexError, hence no end to the iteration.

PS. I am assuming the intended implementation is
act_index = random.randint(0, len(self.samples)-1)
The minus 1 is missing in your snippet.

Yeah, it would go infinite, but I don’t call list() on it. And considering the number of images, I think I would run out of RAM in the original (non-infinite) case anyway.

In fact, for this use case I’d recommend using a Sampler. It is meant to define the item-picking strategy, so it is more in line with the PyTorch design.
https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
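For reference, torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=300) yields 300 indices drawn with replacement from range(len(dataset)), and the DataLoader then fetches dataset[i] for each. A stdlib sketch of that behavior (not the actual PyTorch code; the helper name is made up):

```python
import random

samples = list(range(100))  # stand-in for the dataset

def epoch_indices(n_items, num_samples, seed=None):
    """Sketch of what RandomSampler(replacement=True, num_samples=...) yields."""
    rng = random.Random(seed)
    # random.Random.choices samples with replacement
    return rng.choices(range(n_items), k=num_samples)

idxs = epoch_indices(len(samples), num_samples=300, seed=0)
print(len(idxs))  # 300, regardless of len(samples)
assert all(0 <= i < len(samples) for i in idxs)
```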
