DataLoader returns multiple indices

Here is a reproducible code snippet:

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

class TestDataset(Dataset):
    def __init__(self):
        super().__init__()

    def __getitem__(self, idx):
        print(idx)  # log every index the DataLoader requests
        return idx

    def __len__(self):
        return 100

dataset = TestDataset()  # avoid shadowing the class name with the instance
loader = DataLoader(dataset=dataset,
                    batch_size=2,
                    num_workers=1)

for data in loader:
    print(data)
    break

It prints:

0
1
2
3
4
5
tensor([0, 1])

But why do so many indices get printed? And if you change batch_size to 1, it gives:

0
1
2
tensor([0])

torch==1.5.1
torchvision==0.6.0a0+35d732a

@ptrblck Sorry to bother you, but if you have some spare time, could you take a look? I'd really appreciate it.

I’m not sure if I understand the question correctly, but I think you are concerned about the print(idx) line and are wondering why it prints so many indices before yielding the first batch?

If you are using multiprocessing, the workers will load multiple batches in the background and add them to a queue.
Since the batch creation in your code snippet is really simple, the background worker is able to fetch 3 batches before the DataLoader loop returns the first data batch.

Setting num_workers=0 will only print the 0 and 1 indices before returning the batch.
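The printed numbers can be reproduced with a small model of how the multiprocessing iterator in torch 1.5.x schedules work: it primes the workers with the indices of 2 * num_workers batches, then sends the indices of one more batch each time a batch is handed back to the loop (later releases expose this factor as the prefetch_factor argument). The function below is a hypothetical, simplified sketch, not PyTorch code; it assumes a sequential sampler and a worker fast enough to finish the refill batch before the loop body prints.

```python
def indices_fetched_before_first_batch(batch_size, num_workers,
                                       prefetch_per_worker=2, dataset_len=100):
    """Hypothetical model: indices passed to __getitem__ (sequential sampler)
    by the time the first batch reaches the body of the `for` loop."""
    if num_workers == 0:
        # No workers: the main process lazily loads only the first batch.
        n_batches = 1
    else:
        # Batches primed up front, plus the one refill sent when the first
        # batch is returned (assumes the worker finishes it before the print).
        n_batches = prefetch_per_worker * num_workers + 1
    return list(range(min(n_batches * batch_size, dataset_len)))


print(indices_fetched_before_first_batch(batch_size=2, num_workers=1))  # [0, 1, 2, 3, 4, 5]
print(indices_fetched_before_first_batch(batch_size=1, num_workers=1))  # [0, 1, 2]
print(indices_fetched_before_first_batch(batch_size=2, num_workers=0))  # [0, 1]
```

The exact interleaving of the prints in the real DataLoader depends on process timing, so this only models which indices have been requested, not their order on screen.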

Yes, your understanding of the question is correct.

> Since the batch creation in your code snippet is really simple, the background worker is able to fetch 3 batches before the DataLoader loop returns the first data batch.

Actually, I ran into this because I have a DataLoader with data augmentation and I wanted to check which image got loaded, so I tried printing the index. So I’m not sure this happens only because the dataset is simple, letting the worker fetch 3 batches and print their indices.

What I’m worried about is this: if the batch size is large, say 128, will loading a batch become slower just because the indices keep being generated? I created a notebook to test this, and it seems that with a large batch size the index gets printed many times.

But I don’t know whether this “slowness” comes from the large batch, or from the loader generating indices many times and thus adding extra loading time.

Thanks for your reply.

No, that shouldn’t be the case, as the index generation should be really cheap compared to the data loading and processing. The batch is created by calling __getitem__ once for each of the batch_size indices.
E.g. for a batch size of 128, the DataLoader would call Dataset.__getitem__ 128 times, and thus you would see the index output 128 times.
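To make the per-index calls concrete, here is a plain-Python sketch of what the default map-style fetch does for one batch: it calls __getitem__ once per sampled index and then collates the resulting list. CountingDataset, batched_indices and fetch are hypothetical names, and torch's default collate (which would stack the samples into a tensor) is replaced here by returning the plain list.

```python
class CountingDataset:
    """Hypothetical dataset that records every index it is asked for."""

    def __init__(self, length):
        self.length = length
        self.requested = []

    def __getitem__(self, idx):
        self.requested.append(idx)
        return idx * 10  # stand-in for loading and augmenting one sample

    def __len__(self):
        return self.length


def batched_indices(n, batch_size):
    # Sequential sampling grouped into batches (last batch may be smaller).
    for start in range(0, n, batch_size):
        yield list(range(start, min(start + batch_size, n)))


def fetch(dataset, indices):
    # One __getitem__ call per index; a real DataLoader would then collate.
    return [dataset[i] for i in indices]


dataset = CountingDataset(1000)
first_batch = fetch(dataset, next(batched_indices(len(dataset), 128)))
print(len(dataset.requested))  # 128
```

Producing the 128 indices is a single range slice, so the cost per batch is dominated by whatever __getitem__ does with each index.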

Most likely you see the slowdown due to the data loading and processing.
Note that multiple workers can potentially hide this latency by loading batches in the background once the first iteration has started.
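The latency hiding can be illustrated with a small producer/consumer sketch. It uses a background thread and a bounded queue.Queue instead of worker processes, so it is only an analogy for what the DataLoader workers do; prefetching_loader and slow_load are hypothetical names. Because prefetching_loader is a generator, the background thread only starts when the first batch is requested.

```python
import queue
import threading
import time


def prefetching_loader(batches, load, prefetch=2):
    """Hypothetical helper: a background thread loads batches ahead of the
    consumer and keeps at most `prefetch` finished batches in a queue."""
    q = queue.Queue(maxsize=prefetch)
    done = object()  # sentinel marking the end of the data

    def worker():
        for indices in batches:
            q.put([load(i) for i in indices])  # blocks while the queue is full
        q.put(done)

    # daemon=True: if the consumer stops early, the blocked thread does not
    # keep the interpreter alive.
    threading.Thread(target=worker, daemon=True).start()
    while True:
        batch = q.get()
        if batch is done:
            return
        yield batch


def slow_load(i):
    time.sleep(0.01)  # stand-in for reading and augmenting one sample
    return i


for batch in prefetching_loader([[0, 1], [2, 3], [4, 5]], slow_load):
    time.sleep(0.02)  # stand-in for the training step; loading overlaps it
    print(batch)
```

While the loop body works on one batch, the thread is already loading the next, so after the first batch the consumer mostly waits only if loading is slower than the training step.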

Thanks for the explanation!

It seems I need to set num_workers=0 to get the indices printed in sync with the returned batches.
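If the goal is to check which images end up in each batch while keeping worker processes, one option is to return the index together with the sample from __getitem__, so each collated batch carries its own indices no matter when the workers ran. This is a sketch assuming a hypothetical IndexedDataset, with random tensors standing in for loaded images; the default collate function turns the returned integers into one index tensor per batch. The `if __name__ == "__main__":` guard is needed on platforms that start workers with spawn (e.g. Windows and macOS).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class IndexedDataset(Dataset):
    """Hypothetical dataset that returns (sample, index) pairs."""

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        sample = torch.randn(3)  # stand-in for a loaded, augmented image
        return sample, idx


def collect_batch_indices(num_workers, num_batches=3):
    loader = DataLoader(IndexedDataset(), batch_size=2, num_workers=num_workers)
    seen = []
    for i, (samples, indices) in enumerate(loader):
        seen.append(indices.tolist())  # indices of the samples in this batch
        if i + 1 == num_batches:
            break
    return seen


if __name__ == "__main__":
    print(collect_batch_indices(num_workers=1))  # [[0, 1], [2, 3], [4, 5]]
```

Batches are still delivered in sampler order with workers enabled, so the recorded indices match the batches even though the print inside __getitem__ would run ahead of them.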