What could be the reason for a new epoch taking so much time?

Hello,

I am running PyTorch and found that a lot of time passes at the start of every new epoch (e.g., below, the wall-clock time jumps from 398 s at step 82 to 547 s at step 83, while consecutive steps within an epoch take only 1-2 s), as the log shows:

{"epoch": 0, "step": 80, "lr_weights": 0.004819277108433735, "lr_biases": 0.00011566265060240964, "loss": 8449.654296875, "time": 396}
{"epoch": 0, "step": 81, "lr_weights": 0.004879518072289157, "lr_biases": 0.00011710843373493975, "loss": 9105.5068359375, "time": 397}
{"epoch": 0, "step": 82, "lr_weights": 0.004939759036144579, "lr_biases": 0.00011855421686746987, "loss": 9423.2666015625, "time": 398}
{"epoch": 1, "step": 83, "lr_weights": 0.005000000000000001, "lr_biases": 0.00011999999999999999, "loss": 10010.8828125, "time": 547}
{"epoch": 1, "step": 84, "lr_weights": 0.005060240963855422, "lr_biases": 0.00012144578313253011, "loss": 8652.189453125, "time": 550}
{"epoch": 1, "step": 85, "lr_weights": 0.0051204819277108436, "lr_biases": 0.00012289156626506023, "loss": 9840.9541015625, "time": 552}

I tried to find out why this is the case. Here is the code I am running:

def main_worker(gpu, args):
    args.rank += gpu
    torch.distributed.init_process_group(
        backend='nccl', init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    if args.rank == 0:
        args.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        stats_file = open(args.checkpoint_dir / 'stats.txt', 'a', buffering=1)
        print(' '.join(sys.argv))
        print(' '.join(sys.argv), file=stats_file)

    torch.cuda.set_device(gpu)
    torch.backends.cudnn.benchmark = True

    model = BarlowTwins(args).cuda(gpu)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    param_weights = []  # list of weight tensors
    param_biases = []
    for param in model.parameters():
        if param.ndim == 1:  # i.e. a bias (since the tensor has only one dimension)
            param_biases.append(param)
        else:
            param_weights.append(param)
    parameters = [{'params': param_weights}, {'params': param_biases}]
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
    optimizer = LARS(parameters, lr=0, weight_decay=args.weight_decay,
                     weight_decay_filter=True,
                     lars_adaptation_filter=True)

    # automatically resume from checkpoint if it exists
    if (args.checkpoint_dir / 'checkpoint.pth').is_file():
        ckpt = torch.load(args.checkpoint_dir / 'checkpoint.pth',
                          map_location='cpu')
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    else:
        start_epoch = 0
    
    #=========defined new dataset=============#
    data_path = args.data
    dataset = MRI_dataset(data_path, "train", 1, transform_yAware_all, (1, 99, 117, 95), MNI=True)
    
    sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    per_device_batch_size = args.batch_size // args.world_size
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=per_device_batch_size, num_workers=args.workers,
        pin_memory=True, sampler=sampler, drop_last=True)

    start_time = time.time()
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
            y1 = y1.cuda(gpu, non_blocking=True)
            y2 = y2.cuda(gpu, non_blocking=True)
            adjust_learning_rate(args, optimizer, loader, step)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                loss = model.forward(y1, y2)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            if step % args.print_freq == 0:
                if args.rank == 0:
                    stats = dict(epoch=epoch, step=step,
                                 lr_weights=optimizer.param_groups[0]['lr'],
                                 lr_biases=optimizer.param_groups[1]['lr'],
                                 loss=loss.item(),
                                 time=int(time.time() - start_time))
                    print(json.dumps(stats))
                    print(json.dumps(stats), file=stats_file)
        if args.rank == 0:
            # save checkpoint(saves the whole model)
            state = dict(epoch=epoch + 1, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, args.checkpoint_dir / 'checkpoint.pth')

            #========ADDED_to_save_along_path==========#
            if epoch % args.save_every == 0:
                torch.save(state, args.checkpoint_dir / f'checkpoint_{epoch}.pth')
                torch.save(model.module.backbone.state_dict(),
                   args.checkpoint_dir / f'resnet50_{epoch}.pth')
            #==========================================#
    if args.rank == 0:
        # save final model(saves only the CNN part)
        torch.save(model.module.backbone.state_dict(),
                   args.checkpoint_dir / 'resnet50.pth')

Since the delay occurs right after the step loop finishes, I thought I might disable the checkpoint-saving part to see if that speeds things up, like below (see also the timing sketch after this block):

        #if args.rank == 0:
            # save checkpoint(saves the whole model)
            #state = dict(epoch=epoch + 1, model=model.state_dict(),
                         #optimizer=optimizer.state_dict())
            #torch.save(state, args.checkpoint_dir / 'checkpoint.pth')

            #========ADDED_to_save_along_path==========#
            #if epoch%args.save_every ==0:
            #    torch.save(state, args.checkpoint_dir /'checkpoint_{}.pth'.format(epoch)) 
            #    torch.save(model.module.backbone.state_dict(),
            #       args.checkpoint_dir / f'resnet50_{epoch}.pth')
            #==========================================#
    if args.rank == 0:
        # save final model(saves only the CNN part)
        torch.save(model.module.backbone.state_dict(),
                   args.checkpoint_dir / 'resnet50.pth')
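
An alternative way to check this, instead of commenting the code out, would be to time the save calls directly; a minimal sketch that reuses the variables from the loop above (save_start is just an illustrative name):

        if args.rank == 0:
            save_start = time.time()
            state = dict(epoch=epoch + 1, model=model.state_dict(),
                         optimizer=optimizer.state_dict())
            torch.save(state, args.checkpoint_dir / 'checkpoint.pth')
            # how long does the checkpoint write itself take?
            print(f"rank {args.rank} checkpoint save: {time.time() - save_start:.2f}s")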


However, the same delay still occurred, as shown below. Is there something that can be done to remedy this?

{"epoch": 0, "step": 80, "lr_weights": 0.004819277108433735, "lr_biases": 0.00011566265060240964, "loss": 9116.9658203125, "time": 839}
{"epoch": 0, "step": 81, "lr_weights": 0.004879518072289157, "lr_biases": 0.00011710843373493975, "loss": 10107.34765625, "time": 840}
{"epoch": 0, "step": 82, "lr_weights": 0.004939759036144579, "lr_biases": 0.00011855421686746987, "loss": 9236.630859375, "time": 841}
{"epoch": 1, "step": 83, "lr_weights": 0.005000000000000001, "lr_biases": 0.00011999999999999999, "loss": 8677.29296875, "time": 1079}
{"epoch": 1, "step": 84, "lr_weights": 0.005060240963855422, "lr_biases": 0.00012144578313253011, "loss": 9647.3740234375, "time": 1092}
{"epoch": 1, "step": 85, "lr_weights": 0.0051204819277108436, "lr_biases": 0.00012289156626506023, "loss": 9818.392578125, "time": 1180}
{"epoch": 1, "step": 86, "lr_weights": 0.005180722891566266, "lr_biases": 0.00012433734939759037, "loss": 9051.8212890625, "time": 1265}

Update

After going through the loop line by line and timing it with the modified code snippet below, I found that considerable time is spent initializing the for loop over the dataloader:

    start_time = time.time()
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, args.epochs):
        print(f"rank {args.rank} set epoch start : {time.time() - start_time}")
        sampler.set_epoch(epoch)
        print(f"rank {args.rank} set epoch end : {time.time()- start_time}")
        for step, ((y1, y2), _) in enumerate(loader, start=epoch * len(loader)):
            print(f"rank {args.rank} step start : {time.time()- start_time}")
            y1 = y1.cuda(gpu, non_blocking=True)
            y2 = y2.cuda(gpu, non_blocking=True)

When I tried with one GPU (i.e. with DDP disabled), I got:

rank 0 set epoch start : 4.7206878662109375e-05
rank 0 set epoch end : 8.702278137207031e-05
rank 0 step start : 72.13791012763977
{"epoch": 0, "step": 0, "lr_weights": 0.0, "lr_biases": 0.0, "loss": 13827.8291015625, "time": 82}
rank 0 step start : 82.55937671661377
{"epoch": 0, "step": 1, "lr_weights": 3.765060240963856e-06, "lr_biases": 9.036144578313253e-08, "loss": 12604.126953125, "time": 83}
rank 0 step start : 83.55196452140808

which indicates that considerable time is spent when the for loop over the dataloader is initialized.

The same was true for DDP (two GPUs):

rank 0 set epoch start : 5.91278076171875e-05
rank 0 set epoch end : 0.00011086463928222656
rank 1 set epoch start : 8.869171142578125e-05
rank 1 set epoch end : 0.00013375282287597656
rank 1 step start : 67.45242071151733
rank 0 step start : 69.03039646148682
{"epoch": 0, "step": 0, "lr_weights": 0.0, "lr_biases": 0.0, "loss": 12499.767578125, "time": 80}
rank 1 step start : 80.62654376029968
rank 0 step start : 80.63264775276184
{"epoch": 0, "step": 1, "lr_weights": 1.5060240963855424e-05, "lr_biases": 3.614457831325301e-07, "loss": 9944.970703125, "time": 81}
rank 0 step start : 81.4628357887268
rank 1 step start : 81.45779013633728
{"epoch": 0, "step": 2, "lr_weights": 3.012048192771085e-05, "lr_biases": 7.228915662650602e-07, "loss": 11979.03125, "time": 82}
rank 1 step start : 82.26426815986633
rank 0 step start : 82.26971387863159

Is this normal? What could be causing the dataloader for-loop initialization to take so long?
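
One way to narrow this down further might be to separate the cost of creating the loader iterator (with num_workers > 0, this is where the worker processes are started) from the cost of fetching the first batch; a minimal sketch using the same loader, args, and (y1, y2) unpacking as above (it_start and batch_start are just illustrative names):

it_start = time.time()
it = iter(loader)        # with num_workers > 0, the worker processes are started here
print(f"rank {args.rank} iter(loader): {time.time() - it_start:.2f}s")

batch_start = time.time()
(y1, y2), _ = next(it)   # waits for the first batch to come back from the workers
print(f"rank {args.rank} first batch: {time.time() - batch_start:.2f}s")

If most of the time sits in these two calls rather than in the training steps themselves, that would suggest per-epoch worker startup or slow first reads in MRI_dataset rather than the GPUs or DDP; newer PyTorch versions also have a persistent_workers=True option on DataLoader that keeps the workers alive across epochs, which might be relevant here.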