PyTorch DataLoader Very Slow First Batch

We maintain a popular PyTorch YOLOv3 repository (github.com/ultralytics/yolov3) that uses the DataLoader with multiple workers for performance. The first batch in each epoch always takes several times longer than the rest of the batches, and we've noticed that the DataLoader loads far more images than necessary in this first batch, for no apparent reason. For example, if we print the index requested from the dataset and set batch_size=4, num_workers=2, we see that the DataLoader loads 20 images in the first batch (!!!), and then 4 in every subsequent batch (as expected).

So our question is: why in the world is the DataLoader loading 5X the required number of images in the first batch? This is clearly the cause of the slowdown, but the reason for the behavior remains a mystery to us. It has a severely detrimental effect on training time for smaller datasets.
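For anyone who wants to reproduce this outside the repo, here is a minimal, self-contained sketch (class and variable names are made up for illustration, not taken from ultralytics/yolov3) that prints every index the DataLoader requests. With these settings it should show the same burst of ~20 indices before the first batch line prints:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class IndexEchoDataset(Dataset):
    """Toy dataset that prints every index the DataLoader requests."""

    def __len__(self):
        return 32

    def __getitem__(self, index):
        print(index, flush=True)  # shows when each worker fetches a sample
        return torch.zeros(3, 416, 416), torch.tensor(index)

if __name__ == '__main__':  # guard needed for worker processes on spawn platforms
    loader = DataLoader(IndexEchoDataset(), batch_size=4, num_workers=2)
    for i, (imgs, indices) in enumerate(loader):
        print(f'batch {i}: imgs.shape={tuple(imgs.shape)}')
```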

/Users/glennjocher/.conda/envs/yolov3/bin/python /Users/glennjocher/PycharmProjects/yolov3/train.py
Namespace(accumulate=1, backend='nccl', batch_size=4, cfg='cfg/yolov3-spp.cfg', data_cfg='data/coco_32img.data', dist_url='tcp://127.0.0.1:9999', epochs=273, evolve=False, img_size=416, multi_scale=False, nosave=False, notest=False, num_workers=2, rank=0, resume=False, transfer=False, var=0, world_size=1)
Using CPU

Reading images: 100%|██████████| 32/32 [00:00<00:00, 168.00it/s]
Model Summary: 225 layers, 6.29987e+07 parameters, 6.29987e+07 gradients

   Epoch       Batch        xy        wh      conf       cls     total   targets      time
0
4
1
5
2
6
3
7
12
8
9
13
10
14
11
15
16
17
18
19
   0/272         0/7     0.148    0.0947      36.5      1.38      38.1        13      7.93
20
21
22
23
   0/272         1/7     0.177     0.131      36.5      1.75      38.5        20      7.07
24
25
26
27
   0/272         2/7     0.177     0.162      36.5      1.87      38.7        24      6.62
28
29
30
31
...

Do you print the index in your custom Dataset class? This is most likely caused by the DataLoader preloading data, which is what actually speeds up your later batches overall…
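For reference, the count in the log matches the prefetch arithmetic: the multi-worker DataLoader primes each worker with a couple of batches before yielding anything, so with num_workers=2 and a prefetch depth of 2 batches per worker, roughly 2 × 2 × 4 = 16 images are queued in addition to the 4 images of the batch being returned, i.e. the 20 indices above. A sketch of the relevant knobs, assuming PyTorch >= 1.7 (prefetch_factor and persistent_workers were added there; older releases use a fixed, equivalent prefetch depth):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(32))  # stand-in for the real image dataset
loader = DataLoader(dataset, batch_size=4, num_workers=2,
                    prefetch_factor=2,        # batches kept in flight per worker
                    persistent_workers=True)  # keep workers alive across epochs
```

persistent_workers=True also avoids the cost of re-spawning the worker processes at the start of every epoch, which is relevant here since the slowdown recurs each epoch.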

Could you print the shape of the imgs tensor inside the training loop, i.e. in for i, (imgs, targets, _, _) in enumerate(dataloader):? My guess is that the batch dimension there will still be 4.
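Something like the following, with the loop signature copied from above and dataloader assumed to be the repo's loader:

```python
for i, (imgs, targets, _, _) in enumerate(dataloader):
    # if the extra loads are just worker prefetching, this stays at 4
    print(f'batch {i}: imgs.shape = {tuple(imgs.shape)}')
```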