DataLoader with num_workers>0 and multiple epochs

Hi,
This may have been asked before, but I’m not sure which keywords to search for.

Let’s say we are in a setting where fetching a data item takes a while (e.g., long videos), and we have a DataLoader with num_workers>0. We have an inner loop where we iterate over the DataLoader, and an outer loop where we iterate over epochs so that we see each item more than once.
As far as I know, the whole point of using num_workers>0 is to have the DataLoader spawn multiple processes that fetch the next items ahead of time. No matter what we do, there’s an unavoidable wait in the first iteration: the model is ready to process items but has to wait the (long, in our case) time it takes the DataLoader to fetch the first item. Assuming the DataLoader works faster than the model, this effect disappears afterwards.
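
To put a rough number on “the DataLoader works faster than the model”, here is a back-of-the-envelope sketch. It assumes batch_size=1, a constant load time per item, a constant model step per item, and no other overhead; the helper names are made up for illustration. Each worker loads one item at a time, so `num_workers` workers deliver about `num_workers / load_time` items per second, and they keep up with the model when that is at least `1 / step_time`. Prefetching (the DataLoader’s `prefetch_factor`) hides latency but does not raise this throughput bound.

```python
def min_workers(load_ms: int, step_ms: int) -> int:
    """Smallest worker count whose combined throughput keeps up with the model.

    Hypothetical helper. Assumes batch_size=1, each worker loads one item in
    `load_ms` milliseconds, and the model consumes one item in `step_ms`.
    Integer milliseconds keep the ceiling division exact.
    """
    if load_ms <= 0 or step_ms <= 0:
        raise ValueError("times must be positive")
    # ceil(load_ms / step_ms) without floating-point rounding
    return -(-load_ms // step_ms)


def keeps_up(num_workers: int, load_ms: int, step_ms: int) -> bool:
    """True if `num_workers` loaders produce items at least as fast as the model eats them."""
    return num_workers * step_ms >= load_ms


if __name__ == "__main__":
    # Timings from the toy example further down: 6 s per item, 2 s per model step.
    print(min_workers(6000, 2000))    # 3
    print(keeps_up(2, 6000, 2000))    # False
```

By this estimate the two workers in the toy example further down fall one short, so the model keeps waiting for data even after the first item. That matches the logged output there, where the model finishes each pair of items before the next pair has loaded.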

My question: what happens in the second epoch? My understanding is that, because we finished iterating over all items in the first epoch and are now starting a new iterator, the DataLoader has no way of knowing that it should keep loading items ahead of time as we approach the end of the first epoch.
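
The “new iterator” part can be checked in plain Python: a `for` statement calls `iter()` on its target every time the loop starts, so an outer epoch loop calls `iter(dl)` once per epoch. `CountingIterable` below is a hypothetical stand-in for the DataLoader that only records those calls.

```python
class CountingIterable:
    """Hypothetical stand-in for a DataLoader that records every iter() call."""

    def __init__(self, items):
        self.items = list(items)
        self.iter_calls = 0

    def __iter__(self):
        # A for statement calls this once each time the loop starts.
        self.iter_calls += 1
        return iter(self.items)


def run_epochs(loader, num_epochs):
    seen = []
    for _ in range(num_epochs):
        for item in loader:  # iter(loader) happens here, once per epoch
            seen.append(item)
    return seen


if __name__ == "__main__":
    loader = CountingIterable(range(3))
    print(run_epochs(loader, 2))  # [0, 1, 2, 0, 1, 2]
    print(loader.iter_calls)      # 2
```

As far as I can tell from the PyTorch source, `DataLoader.__iter__` is where the per-epoch iterator is created (or, with persistent_workers=True, reset and reused), so this per-epoch `iter()` call is the hook the loader gets at each epoch boundary.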

Am I correct?

If so, are there any recommendations for restructuring things to avoid the wait at the start of every epoch? I thought about pretending the dataset has len(dataset) * num_epochs items and using a single loop, but before I restructure my code I wanted to make sure that (1) the problem is real and (2) it hasn’t already been solved more elegantly.
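
The single-loop idea could look roughly like this: a hypothetical `RepeatedDataset` wrapper that reports `len(dataset) * num_epochs` items and maps each flat index back to the underlying item with a modulo. It relies only on the map-style `__len__`/`__getitem__` protocol that the DataLoader uses, so it can wrap the original dataset unchanged.

```python
class RepeatedDataset:
    """Hypothetical wrapper exposing `num_epochs` passes over `base` as one dataset.

    Flat index i maps to base[i % len(base)], so a single loop over a
    DataLoader wrapping this object visits every base item `num_epochs` times.
    """

    def __init__(self, base, num_epochs):
        if num_epochs < 1:
            raise ValueError("num_epochs must be >= 1")
        self.base = base
        self.num_epochs = num_epochs

    def __len__(self):
        return len(self.base) * self.num_epochs

    def __getitem__(self, index):
        if not 0 <= index < len(self):
            raise IndexError(index)
        return self.base[index % len(self.base)]
```

It would then be used as `DataLoader(RepeatedDataset(dataset, num_epochs), num_workers=2)` with a single `for` loop. Two caveats with this design: with `shuffle=True` the sampler shuffles across all `len(dataset) * num_epochs` indices, so an item can be seen twice before another is seen once; and anything normally done per epoch (validation, checkpointing, LR schedules) has to be triggered by counting steps instead.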

Thanks!

That’s correct for the default setup with multiple workers, but not if persistent_workers=True is used. From the docs:

persistent_workers (bool, optional) – If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
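
One way to see what keeping the workers alive means in practice is to have the dataset return the PID of the process that loaded each item. This is a quick sketch (`PidDataset` and `worker_pids_per_epoch` are made-up names): with persistent_workers=True the same worker PIDs should show up in every epoch, while with the default False a new set of worker processes is started for each epoch.

```python
import os

from torch.utils.data import DataLoader


class PidDataset:
    """Toy map-style dataset: each item is the PID of the process that loaded it."""

    def __len__(self):
        return 4

    def __getitem__(self, index):
        return os.getpid()


def worker_pids_per_epoch(persistent, num_epochs=2):
    """Return, for each epoch, the set of PIDs that produced that epoch's items."""
    dl = DataLoader(PidDataset(), num_workers=2, persistent_workers=persistent)
    pids = []
    for _ in range(num_epochs):
        # With the default batch_size=1, each batch is a 1-element tensor.
        pids.append({batch.item() for batch in dl})
    return pids


if __name__ == "__main__":  # guard needed where workers are started with "spawn"
    print(worker_pids_per_epoch(persistent=True))   # same PIDs in both epochs
    print(worker_pids_per_epoch(persistent=False))  # new worker processes each epoch
```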

Thanks @ptrblck, you are absolutely correct (as always). At first I suspected that keeping the workers alive is not the same as loading the next epoch’s items ahead of time, but they do seem to do that, which removes the need for the workaround I proposed earlier.

One thing still bothers me, though. To test the claim, I simulated the scenario with a toy example in which both the “dataset” and the “model” do nothing but sleep for a fixed amount of time and print.
With multiple epochs, I see that the data items for the second epoch do indeed start loading before the first epoch finishes, once the workers have finished loading the first epoch’s items.
When I change the code to run a single epoch, though, items for a hypothetical second epoch are not loaded. How does the loader know when to load additional items and when not to? Here’s what I mean:

from time import sleep
from torch.utils.data import DataLoader


# Dummy dataset

class MyDataSet:

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        print(f'(d) start {index}')
        sleep(6) # Simulating some heavy-lifting code
        print(f'(d) finish {index}')
        return self.data[index]

# First version - two epochs

if __name__ == '__main__':
    dataset = MyDataSet(list(range(10)))
    dl = DataLoader(dataset, num_workers=2, persistent_workers=True)
    for epoch in range(2):
        print(f'Starting epoch {epoch}')
        for batch in dl:
            print(f'(m) start {batch.item()}')
            sleep(2) # Simulating some heavy-lifting code, faster than data loading
            print(f'(m) finish {batch.item()}')
        print(f'Finish epoch {epoch}')

This gives the following output:

Starting epoch 0
(d) start 1
(d) start 0
(d) finish 1
(d) finish 0
(d) start 3
(d) start 2
(m) start 0
(m) finish 0
(m) start 1
(m) finish 1
(d) finish 3
(d) start 5
(d) finish 2
(d) start 4
(m) start 2
(m) finish 2
(m) start 3
(m) finish 3
(d) finish 5
(d) start 7
(d) finish 4
(d) start 6
(m) start 4
(m) finish 4
(m) start 5
(m) finish 5
(d) finish 7
(d) start 9
(d) finish 6
(d) start 8
(m) start 6
(m) finish 6
(m) start 7
(m) finish 7
(d) finish 9
(d) finish 8
(m) start 8
(m) finish 8
(m) start 9
(d) start 1
(d) start 0
(m) finish 9
Finish epoch 0
Starting epoch 1
(d) finish 1
(d) finish 0
(d) start 3
(d) start 2
(m) start 0
(m) finish 0
(m) start 1
(m) finish 1
(d) finish 3
(d) finish 2
(d) start 4
(d) start 5
(m) start 2
(m) finish 2
(m) start 3
(m) finish 3
(d) finish 4
(d) start 6
(d) finish 5
(d) start 7
(m) start 4
(m) finish 4
(m) start 5
(m) finish 5
(d) finish 6
(d) start 8
(d) finish 7
(d) start 9
(m) start 6
(m) finish 6
(m) start 7
(m) finish 7
(d) finish 8
(d) finish 9
(m) start 8
(m) finish 8
(m) start 9
(m) finish 9
Finish epoch 1

The interesting part is the following:

(m) finish 8
(m) start 9
(d) start 1
(d) start 0
(m) finish 9
Finish epoch 0
Starting epoch 1

i.e., the workers start loading items 0 and 1 while the main process is still processing item 9, and notably before it finishes epoch 0.
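
Because the workers and the main process all print to the same console, the order of the lines is easier to judge if each line carries a timestamp and the PID of the process that printed it. A possible variant of `MyDataSet` is sketched here; the class name and the `t0`/`load_seconds` parameters are my own additions. It writes each complete line, newline included, with a single `write` call and flushes, which should reduce (though not necessarily eliminate) lines from different processes running together.

```python
import os
import sys
import time


class TimedDataSet:
    """Variant of MyDataSet that stamps each log line with elapsed time and PID.

    `t0` is passed in explicitly so that the worker processes, which receive a
    copy of the dataset, measure time from the same starting point as the
    main process.
    """

    def __init__(self, data, t0, load_seconds=6.0):
        self.data = data
        self.t0 = t0
        self.load_seconds = load_seconds

    def __len__(self):
        return len(self.data)

    def log(self, msg):
        line = f"[{time.time() - self.t0:6.2f}s pid={os.getpid()}] {msg}\n"
        sys.stdout.write(line)  # whole line, newline included, in one call
        sys.stdout.flush()

    def __getitem__(self, index):
        self.log(f"(d) start {index}")
        time.sleep(self.load_seconds)  # simulating heavy loading
        self.log(f"(d) finish {index}")
        return self.data[index]
```

With `dataset = TimedDataSet(list(range(10)), t0=time.time())`, the training loop can call `dataset.log(...)` for the “(m)” lines as well, so both sides share one clock.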

However, when I remove the outer loop over epochs and run:

if __name__ == '__main__':
    dataset = MyDataSet(list(range(10)))
    dl = DataLoader(dataset, num_workers=2, persistent_workers=True)
    print(f'Starting epoch 0') # No outer loop now
    for batch in dl:
        print(f'(m) start {batch.item()}')
        sleep(2)
        print(f'(m) finish {batch.item()}')
    print(f'Finish epoch 0')

I get the following output:

Starting epoch 0
(d) start 0
(d) start 1
(d) finish 0
(d) finish 1
(d) start 2
(d) start 3
(m) start 0
(m) finish 0
(m) start 1
(m) finish 1
(d) finish 2
(d) start 4
(d) finish 3
(d) start 5
(m) start 2
(m) finish 2
(m) start 3
(m) finish 3
(d) finish 4
(d) start 6
(d) finish 5
(d) start 7
(m) start 4
(m) finish 4
(m) start 5
(m) finish 5
(d) finish 6
(d) start 8
(d) finish 7
(d) start 9
(m) start 6
(m) finish 6
(m) start 7
(m) finish 7
(d) finish 8
(d) finish 9
(m) start 8
(m) finish 8
(m) start 9
(m) finish 9
Finish epoch 0

Notice that this time the DataLoader stops and does not load items 0 and 1 again. Loading them is indeed unnecessary here, but how does it distinguish between the two cases, given that it presumably sees the same code in both? This is what puzzles me.