How to prefetch data when processing with GPU?

I’ve spent some time working on this problem across a variety of projects, so I’ve cut and pasted some past thoughts, bullets, etc. from previous discussions. My background involves architecting systems that move large volumes of data from network cards to storage and back again on request, with processing between the steps. It’s a very similar set of concerns.

The two main constraints that usually dominate your PyTorch training performance and your ability to saturate the shiny GPUs are your total CPU IPS (instructions per second) and your storage IOPS (I/O operations per second).

You want the CPUs to be performing preprocessing, decompression, and copying in order to get the data to the GPU. You don’t want them to be idling or busy-waiting on thread/process synchronization primitives, IO, etc. The easiest way to improve CPU utilization with PyTorch is to use the worker process support built into DataLoader. The preprocessing that you do in those workers should use as much native code and as little Python as possible. Use NumPy, PyTorch, OpenCV, and other libraries with efficient vectorized routines written in C/C++. Looping through your data byte by byte in Python will kill your performance, massacre the memory allocator, etc.
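
As a minimal sketch of what that looks like (assumptions for illustration: a `samples` list of (path, label) pairs that you build yourself, and roughly ImageNet-style images), the per-sample work stays in PIL/NumPy/torch calls and the parallelism comes from DataLoader workers:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from PIL import Image

class FileListDataset(Dataset):
    """Hypothetical dataset over a list of (path, label) pairs.
    All per-sample work is done by native-code libraries (PIL, NumPy, torch),
    never Python loops over pixels."""
    def __init__(self, samples, size=224):
        self.samples = samples
        self.size = size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('RGB').resize((self.size, self.size))
        arr = np.asarray(img, dtype=np.uint8)                         # HWC uint8
        tensor = torch.from_numpy(arr).permute(2, 0, 1).contiguous()  # CHW, still uint8
        return tensor, label

# samples = [('/data/train/img_000.jpg', 0), ...]  # assumed: built from your own index
loader = DataLoader(
    FileListDataset(samples),
    batch_size=256,
    shuffle=True,
    num_workers=8,            # tune this; more is not always better (see the IOPS notes below)
    pin_memory=False,         # try both, see the note further down
    persistent_workers=True,  # PyTorch >= 1.7, avoids re-forking workers every epoch
)
```

Float conversion and normalization are deliberately left out of the worker path here; they can be done on the GPU instead (see the prefetcher sketch further down).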

With most common use cases, the preprocessing is done well enough not to be an issue. Things tend to fall apart dramatically once you hit the IOPS limit of your storage. Most simple PyTorch datasets use media stored in individual files.

  • Modern filesystems are good, but when you have thousands of small files and you’re trying to move GB/s of data, reading each file individually can saturate your IOPS long before you can ever maximize GPU or CPU utilization.
  • Just opening a file by name with PIL can involve an alarming number of disk seeks (profile it).
  • Quick fix: buy an NVMe SSD, or two.
  • A SATA SSD is not necessarily going to cut it; you can saturate one with small-to-medium image files and default loader/dataset setups feeding multiple GPUs.
  • Magnetic drives are going to fall on their face
  • If you are stuck with certain drives, or you’ve maxed out the best available, the next move requires more advanced caching, prefetching, or a different on-disk format: move to an index/manifest plus record-based bulk data (like tfrecord or RecordIO) or an efficient memory-mapped/paged in-process DB.
  • I’ve leveraged LMDB successfully with PyTorch, along with a custom simplification of the Python LMDB module; my branch is here (https://github.com/rwightman/py-lmdb/tree/rw). I didn’t document or explain what I did there or why, so ask if curious. A rough sketch of the access pattern follows after this list.
  • Beyond an optimal number (experiment!), throwing more worker processes at the IOPS barrier WILL NOT HELP; it’ll make things worse. You’ll have more processes trying to read files at the same time, and you’ll significantly increase shared memory consumption for the additional queuing, thus increasing the paging load on the system and possibly taking you into thrashing territory that the system may never recover from.
  • Once you have saturated the IOPS of your storage or taxed the memory subsystem and entered into a thrashing situation, it won’t look like you’re doing a whole lot. There will be a lot of threads/processes (including kernel ones) basically not doing much besides waiting for IO, page faults, etc. Behaviour will usually be sporadic and bursty once you cross the line of what can be sustained by your system, much like network link utilization without flow control (queuing theory).
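
For the LMDB route, here is a rough sketch of the read side using the standard py-lmdb bindings (not my branch; the pickled key/value layout is an assumption for illustration). The important detail is opening the environment lazily inside each worker process, since LMDB environments generally shouldn’t be shared across the fork into DataLoader workers:

```python
import pickle

import lmdb
from torch.utils.data import Dataset

class LMDBDataset(Dataset):
    """Rough sketch: one record per sample, keyed by zero-padded index and
    pickled as an (image_bytes, label) tuple. This layout is an assumption,
    not the format my py-lmdb branch uses."""

    def __init__(self, lmdb_path):
        self.lmdb_path = lmdb_path
        self.env = None  # opened lazily, per worker process, after the fork
        env = lmdb.open(lmdb_path, readonly=True, lock=False)
        with env.begin() as txn:
            self.length = txn.stat()['entries']
        env.close()

    def _open(self):
        # readahead=False helps for random access; meminit=False skips buffer zeroing
        self.env = lmdb.open(self.lmdb_path, readonly=True, lock=False,
                             readahead=False, meminit=False)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.env is None:
            self._open()
        with self.env.begin(write=False) as txn:
            payload = txn.get(f'{idx:08d}'.encode('ascii'))
        image_bytes, label = pickle.loads(payload)
        # decode image_bytes with PIL/OpenCV here and convert to a tensor as usual
        return image_bytes, label
```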

Other pointers for a fast training setup with minimal work over the defaults:

  • Employ some of the optimizations in NVIDIA’s examples (https://github.com/NVIDIA/apex/tree/master/examples/imagenet). NVIDIA’s fast_collate and prefetch loader with a GPU normalization step do help a bit; a simplified sketch of the prefetch idea follows after this list.
  • I’ve seen big gains over torch.DataParallel using apex.DistributedDataParallel. Moving from ‘one main process + worker processes + multiple GPUs with DataParallel’ to ‘one process per GPU with apex (and presumably torch) DistributedDataParallel’ has always improved performance for me. Remember to scale down your worker processes per training process accordingly. Higher GPU utilization and less waiting for synchronization usually result, and the variance in batch times is reduced, with the average time moving closer to the peak.
  • Use the SIMD fork of Pillow (Pillow-SIMD) with the default PyTorch transforms, or write your own image loading and processing routines with OpenCV.
  • Don’t leave the DataLoader’s pin_memory=True on by default in your code. There is a reason the PyTorch authors left it defaulting to False. I’ve run into many situations where True causes an extremely negative paging/memory-subsystem impact. Try both.
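
For reference, here is a simplified sketch of the prefetch-on-a-side-stream idea from the NVIDIA example above (my own paraphrase, not their code). It assumes the DataLoader yields NCHW uint8 tensors, and that pinned memory is in use if you want the copy to actually overlap with compute:

```python
import torch

class GPUPrefetcher:
    """Sketch: copy the next batch to the GPU on a separate CUDA stream while the
    current batch is training, and do float conversion + normalization on the GPU
    so only uint8 data crosses the bus."""

    def __init__(self, loader, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.loader = loader
        self.stream = torch.cuda.Stream()
        # 1x3x1x1 so they broadcast over NCHW batches; scaled to the uint8 0-255 range
        self.mean = (torch.tensor(mean) * 255).cuda().view(1, 3, 1, 1)
        self.std = (torch.tensor(std) * 255).cuda().view(1, 3, 1, 1)

    def _preload(self, it):
        try:
            cpu_input, cpu_target = next(it)
        except StopIteration:
            return None, None
        with torch.cuda.stream(self.stream):
            # non_blocking only overlaps if the DataLoader yields pinned tensors
            gpu_input = cpu_input.cuda(non_blocking=True).float()
            gpu_input = gpu_input.sub_(self.mean).div_(self.std)
            gpu_target = cpu_target.cuda(non_blocking=True)
        return gpu_input, gpu_target

    def __iter__(self):
        it = iter(self.loader)
        next_input, next_target = self._preload(it)
        while next_input is not None:
            # default stream must wait for the side-stream copy/normalize to finish
            torch.cuda.current_stream().wait_stream(self.stream)
            next_input.record_stream(torch.cuda.current_stream())
            next_target.record_stream(torch.cuda.current_stream())
            batch = (next_input, next_target)
            # kick off the next batch's copy before handing this one to the training loop
            next_input, next_target = self._preload(it)
            yield batch
```

Usage would just be wrapping your existing loader, e.g. `for input, target in GPUPrefetcher(loader): ...`, with the normalization transform removed from the CPU-side pipeline.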

An observation on the tfrecord/recordio chunking. For IO, even flash-based, randomness is bad and sequential chunks are good, hugely so when you have to move physical disk heads. The random/shuffled nature of training is thus the worst case. When you see gains using record/chunked data, it’s largely because you read the data in sequential chunks. This comes with a penalty: you’ve constrained how random your training samples can be from epoch to epoch. With tfrecord, you usually shuffle once when you build the tfrecord chunks. When training, you can only shuffle again within some number of queued record files (constrained by memory), not across the whole dataset. In many situations this isn’t likely to cause problems, but you do have to be aware of it for each use case and tune your record size to balance performance with how many samples you’d like to shuffle at a time.
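
To make that trade-off concrete, here is a toy sketch (all names here are made up) of the bounded shuffle-buffer pattern that record/chunked pipelines rely on: files are read sequentially, and randomness is limited to whatever fits in the buffer.

```python
import random

def buffered_shuffle(records, buffer_size=10_000, seed=None):
    """Records arrive in sequential chunk order (IO-friendly); shuffling happens
    only within a bounded in-memory buffer, so cross-dataset randomness is
    limited by buffer_size."""
    rng = random.Random(seed)
    buf = []
    for record in records:
        buf.append(record)
        if len(buf) >= buffer_size:
            i = rng.randrange(len(buf))
            buf[i], buf[-1] = buf[-1], buf[i]  # swap a random element to the end...
            yield buf.pop()                    # ...and emit it, keeping the buffer full
    rng.shuffle(buf)  # drain whatever is left at the end of the epoch
    yield from buf

# Hypothetical usage: 'read_records' reads one record/chunk file sequentially.
# def chunk_stream(chunk_paths):
#     for path in chunk_paths:
#         yield from read_records(path)
# for sample in buffered_shuffle(chunk_stream(paths), buffer_size=50_000):
#     ...
```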
