from torch.utils.data import Dataset, DataLoader
import time
import multiprocessing as mp
import torch
class Sleep(Dataset):
def __len__(self): return 20
def __getitem__(self, i):
import time, os
time.sleep(1)
return os.getpid()
if __name__ == "__main__":
mp.set_start_method("fork", force=True)
loader = DataLoader(Sleep(), batch_size=20, num_workers=10, persistent_workers=True)
t0 = time.time()
next(iter(loader))
print("wall = ", time.time() - t0) # should be ≈ 3-4 s, not 180 s
next(iter(loader))
print("wall = ", time.time() - t0) # should be ≈ 3-4 s, not 180 s
next(iter(loader))
print("wall = ", time.time() - t0) # should be ≈ 3-4 s, not 180 s
import pdb; pdb.set_trace()
I have this simple test script. On different systems, with at least 10 cpus each, I run it - each time, each batch takes 20 seconds to load, meaning things are running serially and not in parallel with multiprocessing. Why? How do we actually parallelize the dataloader?