When training separate models on a few GPUs on the same machine, we run into a significant training slowdown that is proving difficult to isolate. Can anyone suggest what may be causing it?
We have a machine with four NVIDIA RTX 3090 GPUs and an AMD Ryzen Threadripper 3960X CPU. We are running multiple instances of a model in parallel to optimize training hyperparameters. A typical training epoch takes 120 minutes, and this holds when training 3 models in parallel on 3 GPUs. However, if we train 4 models, training slows down to 200-300 minutes per epoch for each model, starting with the second epoch. We also noticed that when we increase the batch size from 2 to 3 (we have huge images!), the number of models that can be trained in parallel without a slowdown drops to 2.
I’ve tried investigating the following:
- Overheating - temperatures occasionally reach 95 °C but come down quickly, and the slowdown happens even on GPUs that stay at 75 °C.
- SSD bottleneck - loading is only a small part of the overall time, and I can read the same input image using 24 CPU threads without a noticeable slowdown, so I don’t think this is the cause.
- CPU bottleneck - only 9 of 24 cores are used, and the slowdown only starts after the first epoch, so… no?
- Memory leak? - GPU memory doesn't grow over time.
- Code issue? - We are running the experiments with the MIC-DKFZ/medicaldetectiontoolkit repo on GitHub (2D + 3D implementations of object detectors such as Mask R-CNN, Retina Net, and Retina U-Net, plus a training and inference framework for medical images). At the end of each epoch, it checks the loss values tracked during training, decides whether the current checkpoint belongs to the top n checkpoints, and, if so, deletes an old one and saves the current one. The whole process takes <1 second and nothing seems to be loaded into memory, so I'm not sure how it could be the cause, but the slowdown at the end of the epoch coincides with this step.
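For reference, the top-n bookkeeping described above only needs plain floats and file paths. The sketch below is hypothetical (it is not the toolkit's code; `TopKCheckpoints`, `update`, and `best` are assumed names, and `k` plays the role of n): it keeps the best `k` (loss, path) pairs in a heap and tells the caller what to save and what to delete, without touching the disk itself.

```python
import heapq
from typing import List, Optional, Tuple


class TopKCheckpoints:
    """Hypothetical helper: track the k checkpoints with the lowest loss."""

    def __init__(self, k: int) -> None:
        self.k = k
        # Entries are (-loss, path), so the heap root is the WORST kept checkpoint.
        self._heap: List[Tuple[float, str]] = []

    def update(self, loss: float, path: str) -> Tuple[bool, Optional[str]]:
        """Return (save_new, evicted_path); nothing here touches the disk."""
        loss = float(loss)  # store a plain float, never a tensor
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (-loss, path))
            return True, None
        worst_loss = -self._heap[0][0]
        if loss >= worst_loss:
            return False, None  # not good enough to enter the top k
        # Replace the worst kept checkpoint with the new one.
        _, evicted = heapq.heapreplace(self._heap, (-loss, path))
        return True, evicted

    def best(self) -> List[Tuple[float, str]]:
        """Kept checkpoints sorted from lowest to highest loss."""
        return sorted((-neg_loss, p) for neg_loss, p in self._heap)
```

In a training loop, the caller would save the state dict only when `save_new` is True and `os.remove(evicted)` when an old path is returned. Keeping the bookkeeping to floats and strings means this step cannot hold on to large objects by accident.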
Can anyone suggest where to look to fix this?
Profile your workloads to properly isolate the bottleneck. My guess is that it's either data loading or a CPU that isn't fast enough to schedule all the workloads; the latter should be visible as gaps (whitespace) between kernel launches in a visual profile.
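Before setting up a full profiler, a cheap first check is to split each iteration into time spent waiting for the next batch and time spent in the training step. A minimal sketch, assuming any iterable loader and a caller-supplied `train_step` (`time_iterations` and `window_means` are hypothetical names). With CUDA, `train_step` should call `torch.cuda.synchronize()` before returning, otherwise asynchronous kernel launches make the step time look too small.

```python
import time
from typing import Callable, Iterable, List, Tuple


def time_iterations(loader: Iterable, train_step: Callable[[object], None],
                    max_steps: int) -> List[Tuple[float, float]]:
    """Return (wait_s, step_s) for each iteration.

    wait_s: time blocked in next(loader), i.e. waiting for data.
    step_s: time spent inside train_step (synchronize CUDA inside it).
    """
    timings = []
    batches = iter(loader)
    for _ in range(max_steps):
        t0 = time.perf_counter()
        try:
            batch = next(batches)
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)
        t2 = time.perf_counter()
        timings.append((t1 - t0, t2 - t1))
    return timings


def window_means(timings: List[Tuple[float, float]],
                 window: int) -> List[Tuple[float, float]]:
    """Average (wait_s, step_s) over consecutive windows of `window` batches."""
    out = []
    for start in range(0, len(timings), window):
        chunk = timings[start:start + window]
        out.append((sum(w for w, _ in chunk) / len(chunk),
                    sum(s for _, s in chunk) / len(chunk)))
    return out
```

Logging `window_means(timings, 100)` over an epoch would show whether the wait time grows at some point while the step time stays flat, which would point at data loading rather than GPU compute.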
Here is a snapshot of PyCharm's profiler output. Upgrading to torch 2.0 to enable the torch profiler has led to new out-of-memory crashes, so that is a work in progress.
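Until the torch profiler is usable, a comparable function-level view can be captured with the standard library's `cProfile`, for example around a fixed number of iterations early in the run and again after the slowdown has set in, so the two reports can be compared. `profile_call` is a hypothetical helper; what you pass to it (e.g. a closure running N training iterations) is up to you.

```python
import cProfile
import io
import pstats
from typing import Callable


def profile_call(fn: Callable[[], object], top: int = 15) -> str:
    """Run fn under cProfile and return the `top` entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        fn()
    finally:
        profiler.disable()
    buf = io.StringIO()
    pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(top)
    return buf.getvalue()
```

Note that `cProfile` only sees Python-level calls in the current process; work done in data-loader worker processes or in CUDA kernels shows up only as time spent waiting.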
It looks like the bottleneck is in loading batches. The sleep step is part of the thread synchronization in the data loader from MIC-DKFZ/batchgenerators on GitHub (a framework for data augmentation for 2D and 3D image classification and segmentation). We are using 3 batches of 2 images each, about 200 MB in total. The slowdown happens when we train 3 models in parallel and takes about 900 batches to appear (i.e., training is fast at first, even with 3 models, but eventually slows down).
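For context on why a `sleep` call dominating the main process points at data loading: in a polling producer/consumer design, the consumer sleeps whenever the queue is empty, so its total sleep time measures how far the producers lag behind. The sketch below is a simplified, thread-based illustration of that pattern, not batchgenerators' actual implementation; `poll_batches`, `producer`, and `poll_interval` are assumptions, and the producer simulates loading with `time.sleep`.

```python
import queue
import time
from typing import List, Tuple


def poll_batches(q: "queue.Queue", n_batches: int,
                 poll_interval: float = 0.005) -> Tuple[List[object], float]:
    """Consume n_batches from q by polling; return (batches, total_sleep_s)."""
    batches, slept = [], 0.0
    while len(batches) < n_batches:
        try:
            batches.append(q.get_nowait())
        except queue.Empty:
            # Producers are behind: the consumer can only wait and retry.
            time.sleep(poll_interval)
            slept += poll_interval
    return batches, slept


def producer(q: "queue.Queue", n: int, seconds_per_batch: float) -> None:
    """Stand-in for a data-loading worker; sleep simulates read + augmentation."""
    for i in range(n):
        time.sleep(seconds_per_batch)
        q.put(i)  # blocks if the queue is full
```

When producers are faster than the training step, the queue stays full and the consumer rarely sleeps; when they fall behind, sleep time grows roughly with the per-batch loading cost.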
Does this indicate a data-reading bottleneck? I haven't been able to reproduce it by reading one of the images 100 times with 24 threads; the speed is only slightly lower than with a single thread.
If data reading is the problem, is there anything we can try other than buying a faster SSD, or a way to confirm that a faster SSD would actually solve the issue? Thanks!
I believe the issue is related to RAM. When I run 3 models, system RAM (all 256 GB of it) fills up after several epochs. Swap fills up even with 1 model but doesn't lead to a slowdown, which surprises me.
What exactly is filling the RAM is also puzzling:
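To pin down when RAM and swap fill up relative to the slowdown, it may help to log system memory every few hundred batches alongside the batch timings. A minimal Linux-only sketch that parses `/proc/meminfo`; the helper names are hypothetical, while `MemAvailable`, `SwapTotal`, and `SwapFree` are standard fields of that file (values in kB).

```python
from typing import Dict


def parse_meminfo(text: str) -> Dict[str, int]:
    """Parse /proc/meminfo content into {field: first numeric value}."""
    info = {}
    for line in text.splitlines():
        key, _, rest = line.partition(":")
        fields = rest.split()
        if fields:
            info[key.strip()] = int(fields[0])  # most fields are in kB
    return info


def memory_summary_gib(info: Dict[str, int]) -> Dict[str, float]:
    """Available RAM and used swap in GiB (1 GiB = 1024**2 kB)."""
    kb_per_gib = 1024 ** 2
    return {
        "ram_available_gib": info["MemAvailable"] / kb_per_gib,
        "swap_used_gib": (info["SwapTotal"] - info["SwapFree"]) / kb_per_gib,
    }


def read_memory_summary_gib(path: str = "/proc/meminfo") -> Dict[str, float]:
    """Read and summarize the current system memory state (Linux only)."""
    with open(path) as f:
        return memory_summary_gib(parse_meminfo(f.read()))
```

If `ram_available_gib` falls steadily with the batch count rather than jumping at epoch boundaries, that would favor something in the per-batch path (e.g. the data loader) over the checkpoint step.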
- The network predicts with an ensemble of the n best checkpoints. One step looks at the metrics dictionary to see whether the newest checkpoint is in the top 5 and, if so, saves it to disk and deletes the one it replaces. No checkpoint is loaded. If I comment out this step, the slowdown doesn't happen. So I suspected that the loss dictionary is keeping the whole graph alive (based on your answer in another thread!), and that every time there's a new best checkpoint, another copy gets added to RAM while the original stays referenced in the loss dictionary.
- However, if I reduce the number of batches per epoch from 1200 to 30, the slowdown does not happen after the first several epochs (in which new checkpoints are added). Instead, it happens once around 900-1200 batches have been processed, which points to the data loader not releasing images. But I don't see how this is consistent with the first point. Also, I've never seen a data loader that needs explicit memory clearing before.
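If the metrics/loss dictionary does hold on to tensors, the usual fix is to store Python floats instead. A tensor produced by a forward pass with gradients enabled keeps a reference to its autograd graph (`grad_fn`), and through it to the intermediate results saved for backward, for as long as the tensor itself is referenced; `.item()` (or `.detach()`) breaks that link. This only matters for losses computed with gradients enabled. A minimal sketch; `record_loss` and the dictionary layout are hypothetical, not the toolkit's code.

```python
import torch


def record_loss(metrics: dict, epoch: int, loss: torch.Tensor) -> None:
    """Append (epoch, loss) to metrics as a plain Python float.

    Storing `loss` itself would keep its autograd graph (and the tensors
    saved for backward) alive for as long as `metrics` exists.
    """
    metrics.setdefault("loss", []).append((epoch, loss.item()))


model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()

leaky = {"loss": [(0, loss)]}  # holds a tensor that still has a grad_fn
safe = {}
record_loss(safe, 0, loss)     # holds only a float
```

Checking whether any stored value still has a non-`None` `grad_fn` is a quick way to test the graph-retention hypothesis before looking further at the data loader.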
The actual checkpoint-picker code is pretty simple: