Help in understanding how DataLoader works internally

Hello all, I am having trouble understanding how the DataLoader works internally, especially when the number of workers is specified. I noticed some weird behavior and made a minimal code snippet replicating the issue. Here is the dataset class:


import torch
from torch.utils.data import Dataset, DataLoader


class testClass(Dataset):
    def __init__(self):
        pass

    def __len__(self):
        return 1000

    def __getitem__(self, item):
        print("Accessing the __getitem__ method")
        return torch.rand(10)

I am instantiating the class and testing it as follows:


dataset = testClass()
dataloader = DataLoader(dataset, batch_size=5, num_workers=5)
for _, data in enumerate(dataloader):
    print(data.shape)
    print('--------------------------------------------------')
    break

When batch_size is 5, the output is as follows:

Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method
... ("Accessing the __getitem__ method" printed 45 times in total)
torch.Size([5, 10])
--------------------------------------------------
Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method

There are two things I am not able to understand. Firstly, why is the __getitem__ method being called so many times? Since my batch size is 5, I expect it to be called only 5 times. Secondly, why is it being called after I have printed the dashed lines? I have already received my first batch of data and added a break, so nothing should be printed after the dashed lines. Also, the behaviour is not always the same: sometimes it prints after the dashed lines and sometimes it does not.

This is also the case when I specify num_workers=1 and batch_size=1. The output for this combination is as follows:

Accessing the __getitem__ method
Accessing the __getitem__ method
torch.Size([1, 10])
--------------------------------------------------

Again, it is being called twice.
The only time I see the expected behavior is when I do not pass the num_workers argument. For example, for batch_size=1 without num_workers the output is as follows:

Accessing the __getitem__ method
torch.Size([1, 10])
--------------------------------------------------

and for batch_size=5 the output is as follows (again, without passing the num_workers argument to DataLoader):

Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method
Accessing the __getitem__ method
torch.Size([5, 10])
--------------------------------------------------

While dealing with my original issue, I realized that I had been loading gibberish data all along, that is, wrong labels paired with the input data…
What am I doing wrong here?

Hi,

Firstly, why is the __getitem__ method being called so many times? Since my batch size is 5, I expect it to be called only 5 times.

This is because the worker processes load batches in advance so that they can hand them to the training loop as quickly as possible when it asks for a new one.
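In fact, assuming the default prefetch_factor of 2 (the default in recent PyTorch versions when num_workers > 0), the call count in your output roughly adds up; a back-of-the-envelope sketch:

num_workers = 5
batch_size = 5
prefetch_factor = 2   # batches prefetched per worker (PyTorch default)

# Each worker prefetches `prefetch_factor` batches of `batch_size` samples,
# so roughly this many __getitem__ calls happen before training consumes anything:
expected_calls = num_workers * prefetch_factor * batch_size
print(expected_calls)  # 50, matching the 45 + 5 prints in the output above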

Secondly, why is it being called after I have printed the dashed lines? I have already received my first batch of data and added a break, so nothing should be printed after the dashed lines.

This most likely happens because the workers that load the data run asynchronously and load data before the main process tells them to stop.
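As a sketch (reusing the dataset and dataloader from above), you can make this explicit by driving the iterator manually; the workers start prefetching as soon as the iterator is created and only shut down once it is destroyed:

it = iter(dataloader)   # spawns the worker processes, which start prefetching
data = next(it)         # receive one batch; workers keep prefetching meanwhile
print(data.shape)
del it                  # destroying the iterator tells the workers to shut down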

Again, it is being called twice.

Again, that is because it preloads 2 batches in advance. You can check the DataLoader documentation on master, in particular the prefetch_factor argument, which lets you control how many batches are loaded in advance: torch.utils.data — PyTorch master documentation
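For example (a sketch; prefetch_factor is available from PyTorch 1.7 on and only applies when num_workers > 0):

# Each worker now prefetches a single batch instead of the default 2, so with
# num_workers=1 and batch_size=1 you would expect roughly one __getitem__ call
# before the first batch is returned.
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, prefetch_factor=1)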

Thanks a lot for your reply @albanD.

Here, I noticed that it loaded the data of the first item and the label of the second. This is actually the reason I started digging into the DataLoader…
I am still not sure why that was happening.

I’m not sure what you mean by that? There is no concept of data and label here. The Dataset should return the right pair for a given item.
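For instance, a minimal sketch of such a Dataset (the data and labels tensors here are hypothetical placeholders); since both are indexed with the same item, a sample can never end up paired with another sample's label, no matter how the workers interleave:

class PairedDataset(Dataset):
    def __init__(self, data, labels):
        assert len(data) == len(labels)
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        # Index data and labels with the same `item` so each sample
        # stays matched with its own label.
        return self.data[item], self.labels[item]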

Thanks a lot for replying, I will look into it.