Data loader takes a lot of time for every nth iteration

Hi Pytorchers,

My DataLoader takes a lot of time for every nth iteration and a fraction of a second for all other iterations. What could be the reason? I am using 8 workers for data reading with a batch_size of 128. Here is the log (numbers in parentheses are running averages):

Iter: 0 Epoch: [0/40][0/706] Time: 25.277 (25.277) Loss: 2.1912 (2.1912) lr: 0.0300
Iter: 1 Epoch: [0/40][1/706] Time: 0.732 (13.005) Loss: 2.1298 (2.1605) lr: 0.0300
Iter: 2 Epoch: [0/40][2/706] Time: 0.040 (8.683) Loss: 1.9594 (2.0935) lr: 0.0300
Iter: 3 Epoch: [0/40][3/706] Time: 0.012 (6.515) Loss: 1.8457 (2.0315) lr: 0.0300
Iter: 4 Epoch: [0/40][4/706] Time: 0.001 (5.212) Loss: 1.8034 (1.9859) lr: 0.0300
Iter: 5 Epoch: [0/40][5/706] Time: 0.002 (4.344) Loss: 1.9047 (1.9724) lr: 0.0300
Iter: 6 Epoch: [0/40][6/706] Time: 0.001 (3.724) Loss: 1.9777 (1.9731) lr: 0.0300
Iter: 7 Epoch: [0/40][7/706] Time: 0.001 (3.258) Loss: 1.9933 (1.9757) lr: 0.0300
Iter: 8 Epoch: [0/40][8/706] Time: 25.211 (5.697) Loss: 1.9172 (1.9692) lr: 0.0300
Iter: 9 Epoch: [0/40][9/706] Time: 0.677 (5.195) Loss: 1.8515 (1.9574) lr: 0.0300
Iter: 10 Epoch: [0/40][10/706] Time: 0.001 (4.723) Loss: 1.7753 (1.9408) lr: 0.0300
Iter: 11 Epoch: [0/40][11/706] Time: 0.291 (4.354) Loss: 1.7470 (1.9247) lr: 0.0300
Iter: 12 Epoch: [0/40][12/706] Time: 0.001 (4.019) Loss: 1.7333 (1.9100) lr: 0.0300
Iter: 13 Epoch: [0/40][13/706] Time: 0.487 (3.767) Loss: 1.7477 (1.8984) lr: 0.0300
Iter: 14 Epoch: [0/40][14/706] Time: 0.001 (3.516) Loss: 1.7632 (1.8894) lr: 0.0300
Iter: 15 Epoch: [0/40][15/706] Time: 0.104 (3.302) Loss: 1.7588 (1.8812) lr: 0.0300

Thank you in advance for your help.


The loading might be too slow. Have you tried increasing the number of workers or decreasing the batch size?
Are you working on an HDD or SSD, or do you pull your data from a server over your network?


I am pulling the data from an NFS server, but I am wondering why it impacts only every nth batch. I tried reducing the batch size to 64; it reduces the time, but the pattern is still the same: every nth batch takes much more time than the others.

Well, as you can see, every 8th batch is slower, and you are using 8 workers. Apparently it takes each worker more than 20 seconds to pull 128 samples from the server, while each training step is much faster. You can probably also see the GPU starving in the meantime (nvidia-smi shows zero utilization).
If possible, save the data locally. I doubt increasing the number of workers will give you the desired performance boost; the connection simply seems to be too slow for that.
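To confirm where the time goes, a rough sketch along these lines (the function name and arguments are just placeholders) times the wait for each batch separately from the forward/backward pass:

import time

import torch

def profile_loader(dataloader, model, criterion, optimizer, device, num_iters=20):
    # Measure how long we wait for each batch vs. how long the training step takes.
    data_time, step_time = [], []
    end = time.time()
    for i, (inputs, targets) in enumerate(dataloader):
        data_time.append(time.time() - end)  # time spent waiting for the batch

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        if device.type == 'cuda':
            torch.cuda.synchronize()  # make the GPU timing meaningful
        step_time.append(time.time() - end - data_time[-1])

        end = time.time()
        if i + 1 >= num_iters:
            break
    print('data: %.3fs/iter, compute: %.3fs/iter' %
          (sum(data_time) / len(data_time), sum(step_time) / len(step_time)))

If the data time dominates while the compute time stays small, the loader (or the storage behind it) is the bottleneck.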

Do you have a lot of preprocessing to do?


Also, what kind of data are you using? If it's image data and you are resizing the images as a preprocessing step, you could also try to resize the images on the server and pull these smaller images.
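If the transfer itself is the bottleneck, a rough one-off script along these lines (the paths and target size are placeholders) could pre-shrink the images on the server so that only the smaller files go over NFS:

import os

from PIL import Image

# Placeholder paths; the target size keeps the shorter side at 256 px so that
# a later RandomResizedCrop(224) still has room to work.
src_root, dst_root = '/data/imagenet/train', '/data/imagenet_small/train'
target_size = 256

for dirpath, _, filenames in os.walk(src_root):
    for name in filenames:
        if not name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue
        dst_dir = os.path.join(dst_root, os.path.relpath(dirpath, src_root))
        os.makedirs(dst_dir, exist_ok=True)
        img = Image.open(os.path.join(dirpath, name)).convert('RGB')
        w, h = img.size
        scale = target_size / min(w, h)
        img.resize((round(w * scale), round(h * scale)), Image.BILINEAR).save(
            os.path.join(dst_dir, name))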


Hi @Umar_Iqbal!
Have you solved this problem?
The GPU is almost idle because of this. It's really a waste of time.

@JiangPQ In my case, it turned out that reading the data from the NFS server was the main bottleneck. Moving the data to a local hard drive really improved the speed. How is it in your case: how many workers are you using, how are you reading your data, and do you apply any augmentations?

@ptrblck Yes, I have a lot of preprocessing to do, but reading data from the server was the main bottleneck. The GPU is still starving, though not as much as before. Saving the augmented images on disk would take a lot of space, I guess.

In my case, the ImageNet dataset is on my local HDD. I use torchvision.datasets.ImageFolder and torch.utils.data.DataLoader with batch_size=32, shuffle=True, num_workers=16 to load images. Only a random crop is applied to the images for preprocessing.

The extra time T it costs every Nth iteration is directly proportional to the number of workers N.
T is about 15 s or more with batch_size=32 and num_workers=16.

That's a lot. For batch_size=32, num_workers=16 seems quite high. Have you tried a lower number of workers, say num_workers=4 or 8?

Actually, I’ve tested the time consumption of DataLoader.


pic1: data loading time consumption at every iteration


pic2: accumulation of pic1

Blue is 1 worker, yellow is 8, green is 16, and red is 32.
I only tested 128 steps with batch_size=32.

Using multiple workers does help, but taking 1.5 s on average to load one picture is unreasonable.

Here’s my code:

import time

import torch
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])

train_dir = '...'
train = datasets.ImageFolder(train_dir, transform)

dataloader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True, num_workers=32)

t_sum = 0
last_time = time.time()

for i, data in enumerate(dataloader):
    t = time.time()
    t_diff = t - last_time  # time spent waiting for this batch
    t_sum += t_diff
    draw_function(t_diff, t_sum)  # my plotting helper used for the figures above
    last_time = t

    if i >= 127:
        break

How about replacing RandomResizedCrop with RandomCrop (just for debugging)? Are your images very big? PyTorch (torchvision) uses PIL for all image processing operations, which is not the fastest option. You can write your own transformation that uses OpenCV.
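As a starting point, a custom transform could look roughly like this (a sketch, not a torchvision API; it replaces only the resize step, and the random crop would need a similar OpenCV rewrite to move fully off PIL):

import cv2
import numpy as np
import torch

class OpenCVResize:
    # Accepts a PIL image from ImageFolder's default loader, resizes it with
    # OpenCV, and returns a float CHW tensor in [0, 1] (so it also replaces ToTensor).
    def __init__(self, size):
        self.size = size  # (width, height), the order cv2.resize expects

    def __call__(self, pil_img):
        img = np.array(pil_img)  # HWC, RGB, uint8
        img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
        return torch.from_numpy(img).permute(2, 0, 1).float().div(255)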

@JiangPQ some notes:


Thank you for your advice!
The sequential-imagenet-dataloader improved the data reading speed a lot. Although it took about 10 hours to generate the .lmdb file, it was worth it.

Will PyTorch include it? This would help people without high-performance machines.
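For reference, the general idea behind an LMDB-backed dataset is roughly the following (a generic sketch, not the actual sequential-imagenet-dataloader code; the zero-padded integer keys and JPEG-bytes-per-record layout are assumptions, so adapt them to your own .lmdb layout):

import io

import lmdb
from PIL import Image
from torch.utils.data import Dataset

class LMDBImages(Dataset):
    def __init__(self, lmdb_path, transform=None):
        self.lmdb_path = lmdb_path
        self.transform = transform
        self.env = None  # opened lazily so each DataLoader worker gets its own handle
        env = lmdb.open(lmdb_path, readonly=True, lock=False)
        with env.begin() as txn:
            self.length = txn.stat()['entries']
        env.close()

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        if self.env is None:
            self.env = lmdb.open(self.lmdb_path, readonly=True,
                                 lock=False, readahead=False)
        with self.env.begin() as txn:
            buf = txn.get('{:08d}'.format(index).encode())  # assumed key format
        img = Image.open(io.BytesIO(buf)).convert('RGB')
        return self.transform(img) if self.transform else img

Reading from one big file avoids the per-image metadata lookups of a folder full of JPEGs, which is why it tends to be much faster on an HDD.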

@smth's advice really helps.
It's worth a try!

I noticed that pin_memory=True helps.
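For the pinned-memory path, the relevant bits are the DataLoader flag plus a non-blocking copy to the GPU; a minimal sketch (reusing the train ImageFolder dataset from the earlier snippet) is:

import torch

# pin_memory=True makes the workers return batches in page-locked host memory,
# so the host-to-GPU copy can overlap with compute when issued with non_blocking=True.
loader = torch.utils.data.DataLoader(
    train, batch_size=32, shuffle=True, num_workers=8, pin_memory=True)

for images, targets in loader:
    images = images.cuda(non_blocking=True)
    targets = targets.cuda(non_blocking=True)
    # ... forward / backward as usual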

I am facing the same problem when working with ImageNet data; things are fine when I work with other datasets like Pascal VOC. Volatile GPU utilization also shows 0% most of the time. Any solution to this problem?

I resolved this issue by switching to an M.2 SSD.

What is the optimal num_workers for different batch sizes?

@JiangPQ @sumanth9 This likely means the training pipeline is bottlenecked by data loading time, as shown in the following animation:

The ‘nth iteration problem’ probably just happens because the data loading time has low variance between workers: all workers finish their batches at roughly the same time, so the prefetched batches run out together every num_workers iterations. If the data loading is fast enough, the problem should naturally disappear.

@pen_good The optimal num_workers/batch_size depends on your system and augmentation pipeline. For example, if your storage does not support parallel I/O, you cannot benefit much from increasing num_workers, since you will be I/O-bound. Conversely, if your augmentation pipeline involves a lot of CPU work, it can make sense to use more workers than you would in other pipelines.
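Because of that, a quick sweep is usually the most reliable way to pick a value for a given machine; a rough sketch (the candidate values are just examples, and train is the ImageFolder dataset from the earlier snippet):

import time

import torch

# Time a fixed number of batches for each num_workers setting and pick the fastest.
for num_workers in (0, 2, 4, 8, 16):
    loader = torch.utils.data.DataLoader(
        train, batch_size=32, shuffle=True,
        num_workers=num_workers, pin_memory=True)
    it = iter(loader)
    next(it)  # absorb worker start-up cost
    start = time.time()
    for _ in range(50):
        next(it)
    print('num_workers=%d: %.3fs per batch' % (num_workers, (time.time() - start) / 50))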