Data loading with massive batches?

Hi, I am using a batch size of 640 with the following model: https://github.com/rosinality/relation-networks-pytorch

It’s a simple CNN with a relation network. Most of the wall-clock time is spent just loading the data: I checked GPU utilization, and the GPUs sit idle most of the time. I have 4 GPUs, and they process each batch quickly once it’s loaded. I also have 40 CPUs available as workers.

However, I am in a hurry, and training takes far too long with my current setup (3 weeks!). I believe the bottleneck is data loading. See the __getitem__ and collate functions in dataset.py in the repo linked above.

Is there any way to speed up data loading with such large image batches? Perhaps a way to leverage my 40 CPUs? Any help would be greatly appreciated; this is particularly urgent.

Generally, using multiple workers should speed up your data loading, although the sweet spot depends on your setup.
You could have a look at this post for more information about potential bottlenecks.
Also, you might want to use DALI in case it fits your use case.
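For reference, here is a minimal sketch of the DataLoader knobs that usually matter in this situation. The dataset and transform below are placeholders (substitute your own Dataset), and the worker count is just a starting point to tune:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Placeholder dataset; substitute your own Dataset
# (e.g. the CLEVR dataset from the repo above).
dataset = datasets.FakeData(size=10000, transform=transforms.ToTensor())

loader = DataLoader(
    dataset,
    batch_size=640,
    shuffle=True,
    num_workers=16,           # tune this; more is not always better
    pin_memory=True,          # enables faster host-to-GPU copies
    persistent_workers=True,  # keep workers alive between epochs
    prefetch_factor=4,        # batches pre-loaded per worker
)

for images, targets in loader:
    # non_blocking=True overlaps the copy with compute when pin_memory=True
    images = images.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    # ... forward/backward pass ...
```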

Not to be a nuisance, but is there anything obvious that stands out to you in this data loader? https://github.com/rosinality/relation-networks-pytorch/blob/master/dataset.py

The collate_data function uses a sort, but removing it didn’t fix the intermittent pauses throughout training. If you could take a quick look and something stands out that I didn’t catch, it would be really appreciated.
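For context, the collate function follows roughly this pattern (a paraphrased sketch, not the exact code from the repo):

```python
import numpy as np
import torch

def collate_data(batch):
    # batch: list of (image_tensor, question_token_ids, answer) samples.
    # The sort (by question length, descending) is the one mentioned
    # above; it lets the padded questions be packed for the RNN.
    batch = sorted(batch, key=lambda sample: len(sample[1]), reverse=True)

    images = torch.stack([sample[0] for sample in batch])
    lengths = [len(sample[1]) for sample in batch]

    # Zero-pad the questions to the longest one in the batch.
    questions = np.zeros((len(batch), max(lengths)), dtype=np.int64)
    for i, (_, question, _) in enumerate(batch):
        questions[i, : len(question)] = question

    answers = torch.LongTensor([sample[2] for sample in batch])
    return images, torch.from_numpy(questions), lengths, answers
```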

The Dataset looks good, and apparently the custom collate_fn isn’t responsible for the slowdown.
You could install PIL-SIMD to speed up the image transformations.
PIL-SIMD is a drop-in replacement for PIL and you don’t have to change any calls, so this might be worth a try.
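If you install it, a quick way to confirm the SIMD build is actually the one being imported (the image path below is a placeholder):

```python
import PIL
from PIL import Image

# pillow-simd versions carry a ".postN" suffix (e.g. "7.0.0.post3"),
# so the version string is a quick sanity check that the SIMD build
# replaced stock Pillow in this environment.
print(PIL.__version__)

# The calls themselves are unchanged; resize is one of the operations
# pillow-simd accelerates.
img = Image.open("example.jpg").convert("RGB")  # placeholder path
img = img.resize((128, 128), Image.BILINEAR)
```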