Slow data loading - every couple of batches

Hi
I am using torchrun --nproc_per_node=8 to train SimCLR on ImageNet on 8 GPUs in parallel. I wrap the model with this line:

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], bucket_cap_mb=200)
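For reference, the rest of the distributed setup is the standard torchrun boilerplate, roughly like this (sketched from memory rather than copied verbatim; train_dataset and args.workers stand in for my actual dataset and worker count):

import os
import torch

# torchrun launches 8 processes and sets LOCAL_RANK for each of them
torch.distributed.init_process_group(backend='nccl')
args.gpu = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(args.gpu)

# each rank reads its own shard of ImageNet via DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, sampler=train_sampler,
    num_workers=args.workers, pin_memory=True, drop_last=True)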
The problem is that every couple of batches, data loading is very slow. This is not just the first batch, which I guess would be normal.
Here is the training loop, which also prints the data-loading and total batch times:

end = time.time()
for data_iter, inputs in enumerate(train_loader):
    optim_iter = data_iter // args.update_freq
    # measure data loading time
    data_time.update(time.time() - end)
    data_time_ = time.time() - end

    # update weight decay and learning rate according to their schedule
    it = iters_per_epoch * epoch + optim_iter  # global training iteration
    for k, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr_schedule[it]

    inputs = [tensor.cuda(args.gpu, non_blocking=True) for tensor in inputs]

    # compute output
    with amp.autocast(enabled=not args.disable_amp):
        outputs = model(*inputs)
        loss_dict = criterion(outputs)
        loss = loss_dict['loss']
        loss /= args.update_freq

    scaler.scale(loss).backward()

    if (data_iter + 1) % args.update_freq != 0:
        continue

    # compute gradient and do SGD step
    scaler.step(optimizer)
    scaler.update()
    model.zero_grad(set_to_none=True)

    # clamp logit scale to [0, 100]
    if args.model.startswith('SIMCLR'):
        logit_scale = 0
    else:
        utils.get_model(model).logit_scale.data.clamp_(0, 4.6052)
        logit_scale = utils.get_model(model).logit_scale.exp().item()

    for k in loss_dict:
        metrics[k].update(loss_dict[k].item(), args.batch_size)

    # measure elapsed time
    batch_time.update(time.time() - end)
    batch_time_ = time.time() - end
    end = time.time()

    print(data_time_, batch_time_)
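
One caveat about these numbers: CUDA kernels run asynchronously, so time.time() on the host can attribute GPU wait time to whichever call happens to block next (for example loss_dict[k].item()). If the exact split between loading and compute matters, the timestamps could be taken after a synchronize, roughly like this:

def now():
    # wait for all queued GPU work on this rank before reading the clock
    torch.cuda.synchronize(args.gpu)
    return time.time()

# e.g. end = now() at the bottom of the loop, and
# data_time_ = now() - end right after the DataLoader yields a batch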

Here is the (data_time_, batch_time_) output when j=8 (data-loader workers):

142.89722776412964 172.6102113723755
0.0011818408966064453 1.149657964706421
0.0005125999450683594 0.5908513069152832
0.000789642333984375 0.5789885520935059
0.0006721019744873047 0.5847604274749756
0.0006527900695800781 0.5850253105163574
0.0006945133209228516 0.5916690826416016
0.0006604194641113281 0.5858132839202881
33.176689863204956 113.46687459945679
0.0006616115570068359 3.237961530685425
0.0004947185516357422 0.5849909782409668
0.0006363391876220703 0.5738029479980469
0.0005011558532714844 0.5823519229888916
0.0004696846008300781 0.5900559425354004
0.0006518363952636719 0.5800015926361084
0.0006239414215087891 0.5876047611236572
0.0006771087646484375 70.80163884162903
0.0009722709655761719 0.5865006446838379
0.0006647109985351562 0.5858757495880127
0.0006537437438964844 1.4045929908752441
0.0006687641143798828 32.53645730018616
0.0008571147918701172 0.5862514972686768
0.0006554126739501953 0.5853633880615234
0.0006177425384521484 16.638994216918945
0.0009629726409912109 66.59664487838745
0.001010894775390625 0.588956356048584
0.0006909370422363281 0.5857172012329102
0.0006442070007324219 0.5854201316833496
0.0006425380706787109 29.423442363739014
0.0008411407470703125 0.5904080867767334
0.0007281303405761719 0.5878152847290039
0.0007319450378417969 47.11639881134033
0.0008819103240966797 23.894486904144287
0.0007138252258300781 0.578960657119751
0.0005040168762207031 0.5796892642974854
0.0004954338073730469 0.5855348110198975
0.0004782676696777344 66.34711623191833
0.0006091594696044922 0.5864002704620361
0.0005247592926025391 0.5784976482391357
0.0005164146423339844 0.6909780502319336
0.0006034374237060547 26.061028718948364
0.0008776187896728516 0.5903408527374268
0.0006973743438720703 0.584754467010498
0.0006537437438964844 0.5849916934967041

36.196861028671265 36.79169154167175
0.0008547306060791016 0.5894830226898193
0.0008087158203125 0.5778903961181641
0.0006210803985595703 0.5889377593994141
0.0008003711700439453 0.5878915786743164
0.00077056884765625 0.5870087146759033
0.0007855892181396484 93.86848998069763
0.0007998943328857422 0.5807287693023682
17.311529874801636 17.90171504020691
0.0007803440093994141 9.284274816513062
0.0008406639099121094 0.5794563293457031
0.0008251667022705078 0.6089217662811279
0.00078582763671875 0.5598442554473877
0.0007565021514892578 0.5864059925079346
0.0007340908050537109 42.826006174087524
0.0010673999786376953 0.5904500484466553
23.019705295562744 59.32295536994934
0.0007565021514892578 31.347289085388184
0.0006775856018066406 0.5731685161590576
0.0007195472717285156 0.5763015747070312
0.0005919933319091797 0.5776708126068115
0.0005700588226318359 0.5778248310089111
0.0006148815155029297 7.108304738998413
0.0005848407745361328 0.5788106918334961
0.0006554126739501953 32.21546387672424
0.0007257461547851562 88.52377581596375
0.0008158683776855469 0.5769295692443848

I also noticed that although the GPUs show 100% utilization, their power draw is only around 80/350 W.

I wonder if you are simply bottlenecked by data loading.
Can you try switching the data loading to something like FFCV and see if the issue persists? GitHub - libffcv/ffcv: FFCV: Fast Forward Computer Vision (and other ML workloads!)
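
If you want to try it, a minimal FFCV ImageNet loader looks roughly like this, assuming the dataset has already been converted to a .beton file with ffcv.writer.DatasetWriter (the path, crop size, and worker count below are placeholders):

import numpy as np
import torch
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import RandomResizedCropRGBImageDecoder, IntDecoder
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, NormalizeImage, Squeeze

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406]) * 255
IMAGENET_STD = np.array([0.229, 0.224, 0.225]) * 255
device = torch.device(f'cuda:{args.gpu}')  # same per-rank device as in your script

train_loader = Loader(
    '/path/to/imagenet_train.beton',   # written beforehand with DatasetWriter
    batch_size=args.batch_size,
    num_workers=8,
    order=OrderOption.RANDOM,
    distributed=True,                  # shards batches across the 8 DDP ranks
    drop_last=True,
    pipelines={
        'image': [
            RandomResizedCropRGBImageDecoder((224, 224)),
            ToTensor(),
            ToDevice(device, non_blocking=True),
            ToTorchImage(),
            NormalizeImage(IMAGENET_MEAN, IMAGENET_STD, np.float32),
        ],
        'label': [IntDecoder(), ToTensor(), Squeeze(), ToDevice(device, non_blocking=True)],
    },
)

Note that SimCLR needs two augmented views of each image, so you would have to duplicate the image pipeline (or register a custom field); the sketch above only produces a single view per sample.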