Memory management, DataLoader, optimization and related stuff


I’ve been using PyTorch (Lightning) for almost a year. However, I have little knowledge of CS topics (processes, threads, etc.).

Until now, my approach to optimizing training time has been None, and my reasoning is as simple as: more data = more time, more parameters = more time. I would like to understand in a bit more detail how PyTorch works so I can use it optimally; any recommendation for this is welcome. With that in mind, I would like to know more about the following:

  • In GPU training, does the DataLoader fetch the data from disk, from RAM, or from somewhere else?
  • If the data is transformed, for example with MONAI, TorchIO or torchvision, do these transformations happen every epoch?
  • If I transfer data between devices during preprocessing with something like monai.transforms.ToDevice (basically an enhanced wrapper around to()), will this be time-consuming?
  • In multi-GPU training, do I need to choose my batch size so that each batch fits on each GPU, or is the batch split across GPUs so I can make it bigger?
  • Is there any way to visualize or keep track of the processes created by the DataLoader? The motivation is that if I set num_workers=8, I would like to see the eight processes in something like Linux htop.
  • Don’t kill me for asking this: what is rank_zero and why should I know more about it?
  • What is automatic mixed precision and why should I use it/know more about it?

PS: any reference, resource, explanation, exercise or advice for optimizing training is welcome, thanks.

Generally, I would recommend taking a look at the performance guide if you are trying to optimize your code.

I’ll also try to answer the questions for a general PyTorch setup, as I’m not deeply familiar with Lightning, and I guess it might already optimize some of this for you in its utils.

Where the data comes from depends on the Dataset implementation and on how the data is loaded in __init__ and __getitem__. E.g. ImageFolder lazily loads each sample from the specified image folders in __getitem__ to avoid preloading all samples (which might not even fit into your RAM).
Small datasets, such as MNIST, are loaded into RAM in the __init__ method, and each sample is then just indexed and transformed in __getitem__.
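A minimal sketch of the two patterns, assuming each sample was saved as a tensor file with torch.save. The class names and the temporary files are hypothetical stand-ins for MNIST-style (eager) and ImageFolder-style (lazy) datasets, not their actual implementations.

```python
import os
import tempfile

import torch
from torch.utils.data import Dataset


class InMemoryDataset(Dataset):
    """Eager loading (MNIST-style): all samples are read into RAM once in __init__."""

    def __init__(self, paths):
        self.samples = [torch.load(p) for p in paths]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Only indexing here; no disk access per sample.
        return self.samples[idx]


class LazyDataset(Dataset):
    """Lazy loading (ImageFolder-style): only paths are stored; samples are read on access."""

    def __init__(self, paths):
        self.paths = list(paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # The disk read happens here, every time this sample is requested.
        return torch.load(self.paths[idx])


# Write a few dummy tensors to a temporary folder as stand-ins for real files.
tmp_dir = tempfile.mkdtemp()
paths = []
for i in range(4):
    path = os.path.join(tmp_dir, f"sample_{i}.pt")
    torch.save(torch.full((3,), float(i)), path)
    paths.append(path)

eager = InMemoryDataset(paths)
lazy = LazyDataset(paths)
print(eager[2], lazy[2])
```

Both return the same data; the difference is only when the disk is touched and how much RAM is held.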

Yes. In the standard setup, the transformation is applied to each sample in the __getitem__ of your Dataset. This means that each sample is loaded and transformed on the fly every time it is requested, i.e. again in every epoch (which is also why random augmentations differ between epochs).
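A small sketch that counts how often the transform in __getitem__ runs over several epochs. The dataset and its counter are hypothetical and only for illustration; num_workers=0 is used so the counter lives in the main process (with worker processes, each worker would increment its own copy).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class NoisyDataset(Dataset):
    """Applies a random transform in __getitem__ and counts how often it runs."""

    def __init__(self, data):
        self.data = data
        self.transform_calls = 0  # illustration only

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        self.transform_calls += 1
        # Random augmentation: recomputed on every access, so it differs per epoch.
        return x + 0.1 * torch.randn_like(x)


dataset = NoisyDataset(torch.zeros(8, 2))
# num_workers=0 keeps loading in the main process so the counter is visible here.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for epoch in range(3):
    for batch in loader:
        pass

print(dataset.transform_calls)  # 8 samples x 3 epochs = 24
```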

It depends on what you mean by “time-consuming”. Pushing the data to the GPU will of course take time. However, compared to the actual model training it might be tiny, and if you are using pinned memory and an asynchronous copy (non_blocking=True), the copy won’t block the CPU.
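A minimal sketch of pinned memory plus an asynchronous copy, assuming a toy TensorDataset and linear model. It falls back to the CPU when no GPU is available, in which case pinning is skipped and the .to() call is just a regular (no-op) move.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_cuda = device.type == "cuda"

dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

# pin_memory=True makes the DataLoader return batches in page-locked host memory,
# which is what allows asynchronous host-to-device copies. Only useful with a GPU.
loader = DataLoader(dataset, batch_size=16, pin_memory=use_cuda)

model = torch.nn.Linear(10, 2).to(device)

for x, y in loader:
    # With pinned source memory, non_blocking=True lets the copy overlap with CPU work.
    # Without a GPU there is nothing to overlap and this is simply harmless.
    x = x.to(device, non_blocking=True)
    y = y.to(device, non_blocking=True)
    out = model(x)

print(out.shape, out.device)
```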

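In the recommended approach (a single process per device using DistributedDataParallel), each process is responsible for its own data loading and will thus use its own batch_size. In other words, the batch size you set is per GPU and has to fit on each device, while the effective global batch size is batch_size multiplied by the number of GPUs. To avoid loading the same samples in different processes, a DistributedSampler is usually used.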
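A sketch of how a DistributedSampler splits a dataset between processes, assuming 2 GPUs and a toy dataset of 10 indices. Passing num_replicas and rank explicitly lets this run without an initialized process group; in a real DDP script each process would only build the sampler for its own rank.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))
world_size = 2          # e.g. 2 GPUs -> 2 processes
per_gpu_batch_size = 4

indices_per_rank = {}
for rank in range(world_size):
    # Explicit num_replicas/rank: no torch.distributed initialization needed here.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=per_gpu_batch_size, sampler=sampler)
    indices_per_rank[rank] = [int(i) for (batch,) in loader for i in batch]

print(indices_per_rank)  # {0: [0, 2, 4, 6, 8], 1: [1, 3, 5, 7, 9]}
print("effective batch size:", per_gpu_batch_size * world_size)
```

Each rank sees a disjoint half of the data with its own batch size of 4. With shuffle=True, calling sampler.set_epoch(epoch) at the start of every epoch gives a different shuffling order per epoch.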

Yes, the processes should be visible in htop and ps auxf.
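To match the entries in htop to your DataLoader, a small sketch like this prints the main process ID and the IDs of the worker processes that actually loaded the samples. The dataset is a hypothetical toy that just returns os.getpid(). By default the workers only exist while the loader is being iterated (unless persistent_workers=True), so look at htop during an epoch; in ps auxf they appear as children of the main Python process.

```python
import os

from torch.utils.data import DataLoader, Dataset


class PidDataset(Dataset):
    """Hypothetical toy dataset: returns the PID of the process that loaded each sample."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return os.getpid()


def collect_worker_pids(num_workers):
    loader = DataLoader(PidDataset(), batch_size=2, num_workers=num_workers)
    pids = set()
    for batch in loader:
        pids.update(batch.tolist())  # default collate turns the ints into a tensor
    return pids


if __name__ == "__main__":
    print("main process PID:", os.getpid())
    # These are the worker processes to look for in htop / ps auxf.
    print("worker PIDs:", sorted(collect_worker_pids(num_workers=2)))
```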

I don’t know where rank_zero is coming from, but ranks are usually used in a distributed training setup to identify each process, and rank 0 is often treated as the main process.
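As a rough sketch of how ranks are commonly used (an illustration of the general pattern, not Lightning’s implementation): each process in a distributed job gets a rank from 0 to world_size - 1, and work that should happen only once, such as logging or saving checkpoints, is often guarded so that only rank 0 does it. The decorator and function names are hypothetical; Lightning has its own rank-zero utilities, which are presumably what its docs refer to.

```python
import functools

import torch.distributed as dist


def get_rank():
    """Rank of the current process; 0 when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def only_on_rank_zero(fn):
    """Hypothetical helper: run fn only in the rank-0 process, return None elsewhere."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if get_rank() == 0:
            return fn(*args, **kwargs)
        return None

    return wrapper


@only_on_rank_zero
def log_metrics(step, loss):
    # With N GPUs there are N processes; without this guard each one would print.
    msg = f"step {step}: loss={loss:.4f}"
    print(msg)
    return msg


log_metrics(10, 0.1234)
```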

Automatic mixed precision can speed up your training and save memory, especially if you are using “newer” GPUs with Tensor Cores. Check the AMP tutorial and the performance guide for more information.
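A minimal sketch of a mixed-precision training step with torch.autocast and a GradScaler, assuming a toy model and random data. On a GPU it uses float16 with loss scaling; on a CPU-only machine it falls back to bfloat16 and disables the scaler so the example still runs. It uses torch.cuda.amp.GradScaler, which newer releases also expose (and prefer) as torch.amp.GradScaler("cuda", ...).

```python
import torch
import torch.nn.functional as F

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
# float16 on CUDA (needs loss scaling); bfloat16 is the usual choice for CPU autocast.
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

model = torch.nn.Sequential(
    torch.nn.Linear(32, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# GradScaler protects float16 gradients from underflow; disabled on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

x = torch.randn(16, 32, device=device)
y = torch.randn(16, 1, device=device)

for step in range(5):
    optimizer.zero_grad(set_to_none=True)
    # Inside autocast, eligible ops (e.g. the matmuls in Linear) run in lower precision.
    with torch.autocast(device_type=device, dtype=amp_dtype):
        out = model(x)
    # Compute the loss in float32 outside the autocast region.
    loss = F.mse_loss(out.float(), y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

print(out.dtype, loss.item())
```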


Hey @ptrblck,

Sorry for the delay. Just one detail: rank_zero is mentioned several times in the Lightning documentation, so maybe I could ask someone in their Slack to explain it in more depth (and read the docs again more calmly). I don’t come from a CS background and I sometimes feel lost with so many new concepts, so I really appreciate your answer. One more general question: do you think that getting familiar with C++/CUDA and operating systems can be beneficial to me as a researcher/developer in general?

Thanks, Yere.

Yes, I think any computer science knowledge is beneficial for researchers as well; in particular, understanding how the hardware works can help in isolating bottlenecks etc.
