Gradient accumulation with DDP no_sync interface

I am currently trying to perform gradient accumulation in a DistributedDataParallel setting to simulate large batch sizes on a small number of GPUs. As far as I know, this should be possible by skipping the AllReduce operation on the intermediate steps.

Horovod features the backward_passes_per_step option for such cases, and I read that the same should be possible in PyTorch DDP with the no_sync interface.

My original batch size is 64, and I am trying to “simulate” a batch size that is 4x as large (= 256).

This is a snippet of my current code:

    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)

        if i % 4 == 0:
            # perform a regular AllReduce of the gradients every four steps
            output = model(images)
            loss = criterion(output, target)
            train_loss += loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        else:
            # accumulate gradients locally using the no_sync interface
            with model.no_sync():
                output = model(images)
                loss = criterion(output, target)
                train_loss += loss

                # skipping the optimizer call here
                loss.backward()

However, when I compare the training loss of this “simulated” large-batch run (with gradient accumulation) against a run with an actual batch size of 256, the gradient-accumulation run ends up with a lower loss.

I suspect there are some hidden synchronization steps in my code that occur despite the use of the no_sync interface (and therefore give me a better training loss). Is that the case here, and how could it be fixed?

This looks like correct usage. Regarding the loss difference, have you checked how the gradients differ between the no_sync run and the larger-batch-size run, and made sure that your training is as deterministic as possible?
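
Something along these lines could help with that comparison. It is only a sketch: log_grad_norms is a hypothetical helper name, and the determinism flags trade speed for reproducibility.

    import torch

    # make the two runs comparable
    torch.manual_seed(0)                       # same seed on every rank
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    def log_grad_norms(model, step):
        # call this right after loss.backward() in both runs and diff the output
        for name, param in model.named_parameters():
            if param.grad is not None:
                print(f"step {step} {name}: {param.grad.norm().item():.6f}")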

Also, I think the typical practice is to call optimizer.step() only after the gradients have been allreduced; otherwise the model may go out of sync across ranks. A rough sketch of that ordering is below.
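
This follows roughly the pattern shown in the DDP documentation for no_sync, adapted to your loop: the first accumulation_steps - 1 micro-batches run their backward inside no_sync (gradients just pile up locally), and the last micro-batch runs a normal backward, which synchronizes the accumulated gradients, followed by the optimizer step. The accumulation_steps name is just for illustration, and dividing the loss assumes your criterion uses mean reduction, so that the accumulated gradient matches the mean over the large batch:

    accumulation_steps = 4

    for i, (images, target) in enumerate(train_loader):
        images = images.to(device)
        target = target.to(device)

        if (i + 1) % accumulation_steps != 0:
            # intermediate micro-batches: no AllReduce, gradients
            # accumulate locally in param.grad
            with model.no_sync():
                output = model(images)
                loss = criterion(output, target) / accumulation_steps
                loss.backward()
        else:
            # last micro-batch of the window: this backward synchronizes the
            # accumulated gradients across ranks, then we step and reset
            output = model(images)
            loss = criterion(output, target) / accumulation_steps
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()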