Enumerate(dataloader) slow

Calls to enumerate(dataloader) are quite slow in my project. I’m loading images stored in LMDB format, and I have multiple LMDBs that I combine with ConcatDataset to build the final dataset. I noticed that reducing num_workers from, say, 6 to 1 or 2 reduces the proportion of time spent in enumerate(), though that slows down the loading of images. I was previously passing a WeightedRandomSampler() to the DataLoader with shuffle=False; when I tried shuffle=True and sampler=None, that also reduced the slowness around enumerate(). Despite these changes, about 20-30% of the training time goes to waiting for enumerate(dataloader) to start each epoch. My code looks like this:

for epoch in range(start_epoch, total_epochs + 1):
    for _, train_data in enumerate(train_loader):
        # do work
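
As a rough way to see where the time goes, one can time the iterator creation (where the workers are spawned) separately from the first batch fetch; a minimal sketch, reusing the train_loader from above:

import time

# Time worker startup (iter) separately from fetching the first batch (next).
t0 = time.time()
loader_iter = iter(train_loader)   # worker processes are created here
t1 = time.time()
first_batch = next(loader_iter)    # blocks until the first batch is ready
t2 = time.time()
print(f"iterator creation: {t1 - t0:.2f}s, first batch: {t2 - t1:.2f}s")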

I have confirmed that the hanging/slowness is due to enumerate and not the work being done during each epoch. Looking at the documentation (torch.utils.data — PyTorch 2.1 documentation), I can see that the slowness is likely due to the amount of work done in each call to enumerate():

In this mode, each time an iterator of a DataLoader is created (e.g., when you call enumerate(dataloader)), num_workers worker processes are created. At this point, the dataset, collate_fn, and worker_init_fn are passed to each worker, where they are used to initialize, and fetch data. This means that dataset access together with its internal IO, transforms (including collate_fn) runs in the worker process. … Workers are shut down once the end of the iteration is reached, or when the iterator becomes garbage collected.

My question is: what can we do to eliminate the slowness around enumerate? Is it possible to keep the worker processes created by the first call to enumerate and pass them a different shuffling of the data? Or any other ideas?

Other info: I’m using the Apex mixed-precision package, training on V100 GPUs, and using torch.nn.parallel.DistributedDataParallel on 4 GPUs. Trying apex.parallel.DistributedDataParallel did not significantly reduce the slowness. The hanging is still significant when training on one or two GPUs.


If your Dataset.__init__ method is slow due to some heavy data loading, you would see that slowdown every time the workers are recreated.
The recreation of the workers itself might add a small slowdown, but it should be negligible if you use lazy loading and don’t need a lot of resources in the __init__ method.

Could you check which operations are used in __init__?


Here’s what __init__ looks like for my dataset class:

def __init__(self, opt):
    # Requires: import os, pickle, lmdb
    # Open the GT LMDB read-only and load its cached key list, dropping '.meta' entries.
    self.GT_env = lmdb.open(opt['dataroot_GT'], readonly=True, lock=False, readahead=False, meminit=True)
    keys = pickle.load(open(os.path.join(opt['dataroot_GT'], '_keys_cache.p'), "rb"))
    self.paths_GT = sorted([key for key in keys if not key.endswith('.meta')])

    # Same for the LQ LMDB.
    self.LQ_env = lmdb.open(opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, meminit=True)
    keys = pickle.load(open(os.path.join(opt['dataroot_LQ'], '_keys_cache.p'), "rb"))
    self.paths_LQ = sorted([key for key in keys if not key.endswith('.meta')])

It takes 0.005 seconds to run __init__ per dataset, and I concatenate 8 datasets to build the final dataset. Let me know if you have any suggestions. Thanks for looking!

@vibe2 A researcher submitted a solution to this issue to my codebase a while back; apparently it has also been discussed in a PyTorch issue/PR for some time.

Here is the PR that was submitted, with some info about the solution from the author: https://github.com/rwightman/pytorch-image-models/pull/140
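
The core idea (a rough sketch of the approach, paraphrased from memory of the linked PR, so names and details may differ from the actual code) is to wrap the batch sampler so it never runs out, create the worker iterator once, and have __iter__ hand out one epoch’s worth of batches at a time:

import torch


class _RepeatSampler:
    """Wraps a batch sampler so it yields batches of indices forever."""

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from self.sampler


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses its workers across epochs.

    The wrapped batch sampler never ends, so the single iterator created
    here (and the worker processes behind it) survives; __iter__ just
    yields one epoch's worth of batches from it each time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # DataLoader refuses to reassign batch_sampler after construction,
        # so temporarily clear its (name-mangled) initialization guard flag.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        # Length of the original batch sampler = number of batches per epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)

It can then be used as a drop-in replacement for DataLoader; because the underlying random sampler is re-iterated on every pass, you still get a fresh shuffle each epoch when shuffle=True.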

Works great. Thank you!


One alternative is to use the persistent_workers parameter available in DataLoader:
persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This keeps the workers’ Dataset instances alive. (default: False)
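
A minimal sketch of how that would look here (the dataset and batch size are placeholders):

from torch.utils.data import DataLoader

# Workers are created on the first iteration and then kept alive across
# epochs instead of being respawned by each enumerate() call.
# Note: persistent_workers requires num_workers > 0.
train_loader = DataLoader(
    train_dataset,          # e.g. the ConcatDataset of LMDB datasets
    batch_size=32,          # placeholder
    shuffle=True,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,
)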


Note that it isn’t quite the same. The PR’s implementation keeps the workers alive but also pre-fetches for the next epoch, because its sampler is infinite. My understanding is that PyTorch’s persistent_workers feature won’t pre-fetch for the next epoch, since the sampler is exhausted at the end of each epoch.
