Thanks @ptrblck, you are absolutely correct (as always). At first glance I suspected that keeping the workers alive would not be equivalent to loading the next epoch’s items ahead of time, but they do seem to do exactly that, which removes the need for the tailored solution I offered earlier.
One thing still bothers me, though. To test the claim, I simulated the scenario with a toy example in which both the “dataset” and the “model” do nothing but sleep for a fixed amount of time and print.
In a multi-epoch setting I can indeed see that the data items for the second epoch are loaded before the first epoch finishes, as soon as the workers are done loading the previous iteration’s items.
When I change the code to a single epoch, though, items for a hypothetical second epoch are not loaded. How does the loader know when to load additional items and when not to? Here’s what I mean:
from time import sleep
from torch.utils.data import DataLoader


# Dummy dataset
class MyDataSet:
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        print(f'(d) start {index}')
        sleep(6)  # Simulating some heavy-lifting code
        print(f'(d) finish {index}')
        return self.data[index]


# First version - two epochs
if __name__ == '__main__':
    dataset = MyDataSet(list(range(10)))
    dl = DataLoader(dataset, num_workers=2, persistent_workers=True)
    for epoch in range(2):
        print(f'Starting epoch {epoch}')
        for batch in dl:
            print(f'(m) start {batch.item()}')
            sleep(2)  # Simulating some heavy-lifting code, faster than data loading
            print(f'(m) finish {batch.item()}')
        print(f'Finish epoch {epoch}')
This gives the following output:
Starting epoch 0
(d) start 1
(d) start 0
(d) finish 1(d) finish 0
(d) start 3(d) start 2
(m) start 0
(m) finish 0
(m) start 1
(m) finish 1
(d) finish 3
(d) start 5
(d) finish 2
(d) start 4
(m) start 2
(m) finish 2
(m) start 3
(m) finish 3
(d) finish 5
(d) start 7
(d) finish 4
(d) start 6
(m) start 4
(m) finish 4
(m) start 5
(m) finish 5
(d) finish 7
(d) start 9
(d) finish 6
(d) start 8
(m) start 6
(m) finish 6
(m) start 7
(m) finish 7
(d) finish 9
(d) finish 8
(m) start 8
(m) finish 8
(m) start 9
(d) start 1(d) start 0
(m) finish 9
Finish epoch 0
Starting epoch 1
(d) finish 1
(d) finish 0(d) start 3
(d) start 2
(m) start 0
(m) finish 0
(m) start 1
(m) finish 1
(d) finish 3
(d) finish 2
(d) start 4
(d) start 5
(m) start 2
(m) finish 2
(m) start 3
(m) finish 3
(d) finish 4
(d) start 6
(d) finish 5
(d) start 7
(m) start 4
(m) finish 4
(m) start 5
(m) finish 5
(d) finish 6
(d) start 8
(d) finish 7
(d) start 9
(m) start 6
(m) finish 6
(m) start 7
(m) finish 7
(d) finish 8
(d) finish 9
(m) start 8
(m) finish 8
(m) start 9
(m) finish 9
Finish epoch 1
The interesting part is this:
(m) finish 8
(m) start 9
(d) start 1(d) start 0
(m) finish 9
Finish epoch 0
Starting epoch 1
i.e. the workers start loading items 0 and 1 while the main process is still busy with item 9, and notably before it has finished epoch 0.
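As far as I understand, each worker keeps prefetch_factor batches in flight (the default is 2 per worker), so the amount of lookahead visible above should change if I tweak that argument. A minimal variation of the loader construction above, assuming a PyTorch version that exposes prefetch_factor:

# Hypothetical variation: ask each worker to keep 4 batches in flight instead of 2.
# With a larger value I would expect the workers to run even further ahead of the
# main process in the printout; this is just my assumption about that knob, not a
# claim about when next-epoch loading is triggered.
dl = DataLoader(
    dataset,
    num_workers=2,
    persistent_workers=True,
    prefetch_factor=4,
)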
However, when I remove the for loop over epochs, so that the code becomes:
if __name__ == '__main__':
    dataset = MyDataSet(list(range(10)))
    dl = DataLoader(dataset, num_workers=2, persistent_workers=True)
    print(f'Starting epoch 0')  # No outer loop now
    for batch in dl:
        print(f'(m) start {batch.item()}')
        sleep(2)
        print(f'(m) finish {batch.item()}')
    print(f'Finish epoch 0')
I get the following output:
Starting epoch 0
(d) start 0
(d) start 1
(d) finish 0
(d) finish 1(d) start 2
(d) start 3
(m) start 0
(m) finish 0
(m) start 1
(m) finish 1
(d) finish 2
(d) start 4
(d) finish 3
(d) start 5
(m) start 2
(m) finish 2
(m) start 3
(m) finish 3
(d) finish 4
(d) start 6
(d) finish 5
(d) start 7
(m) start 4
(m) finish 4
(m) start 5
(m) finish 5
(d) finish 6
(d) start 8
(d) finish 7
(d) start 9
(m) start 6
(m) finish 6
(m) start 7
(m) finish 7
(d) finish 8
(d) finish 9
(m) start 8
(m) finish 8
(m) start 9
(m) finish 9
Finish epoch 0
Notice that this time the loader stops and does not load items 0 and 1 again. Loading them is indeed not required here, but how does it distinguish between the two cases, given that it presumably executes the same code in both? This is what puzzles me.
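To narrow down when the prefetch is triggered, I plan to create the second iterator explicitly instead of relying on an outer for loop, and watch when the '(d) start' lines appear. This is just a probing sketch of my own, reusing the same MyDataSet as above:

if __name__ == '__main__':
    dataset = MyDataSet(list(range(10)))
    dl = DataLoader(dataset, num_workers=2, persistent_workers=True)

    # First pass, exactly as in the single-epoch version above
    for batch in dl:
        print(f'(m) start {batch.item()}')
        sleep(2)
        print(f'(m) finish {batch.item()}')

    print('First pass finished, sleeping before creating a new iterator')
    sleep(20)  # if '(d) start ...' appears here, prefetch starts without a new iterator
    it = iter(dl)
    sleep(20)  # if it appears only here, creating the iterator is the trigger
    print('Consuming one batch from the new iterator')
    print(next(it).item())

Depending on where the '(d) start 0' / '(d) start 1' lines show up relative to the two sleeps, this should at least tell me whether the trigger is the creation of the new iterator or something that happens earlier, as in the two-epoch run.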