DDP - sync gradients during optim step instead of backward

If I understand correctly, the DDP wrapper on my model synchronizes gradients at every backward pass. This seems to cause a massive performance bottleneck due to communication between GPUs at every mini-batch.
Because I am accumulating gradients over multiple mini-batches, I would like the gradient synchronization to happen only at the optimizer step instead of at every mini-batch backward pass, which would massively speed up my training and be theoretically equivalent.
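
(For reference, this is the kind of minimal check I have in mind to confirm point 1 below. It assumes one process per GPU on a single node, an already-initialized NCCL process group (e.g. via torchrun), and a toy linear model as a placeholder: after one backward() on different data per rank, every rank should already hold the same averaged gradients.)

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes single node, so local rank == global rank, and an initialized process group
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")

model = DDP(torch.nn.Linear(10, 1).to(device), device_ids=[rank])

# Feed different data to each rank
torch.manual_seed(rank)
x = torch.randn(8, 10, device=device)
model(x).sum().backward()  # DDP should all-reduce (average) the grads here

# Gather each rank's local gradient and compare
grad = model.module.weight.grad.clone()
grads = [torch.empty_like(grad) for _ in range(dist.get_world_size())]
dist.all_gather(grads, grad)
if rank == 0:
    print(all(torch.allclose(g, grads[0]) for g in grads))  # expect: True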

Questions are:

  1. Is my understanding of DDP correct?
  2. If so, do you have any hints on how to sync the gradients only at the optimizer step instead of at every backward pass?

Thanks in advance.

EDIT: Is this Claude-generated solution the best way to go about it?

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Setup (assumes the process group is already initialized, e.g. via torchrun)
model = YourModel().to(device)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
optimizer = torch.optim.Adam(model.parameters())
world_size = dist.get_world_size()

# Hyperparameters
accumulation_steps = 4  # Number of batches to accumulate before synchronizing
optimizer.zero_grad()

for epoch in range(num_epochs):
    for i, (inputs, targets) in enumerate(dataloader):
        # Forward and backward pass - both must run inside no_sync(), otherwise
        # DDP will still synchronize at backward (the sync decision is made
        # during the forward pass)
        with model.no_sync():  # This is the key part - disables synchronization
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)

            # Scale loss by accumulation steps
            loss = loss / accumulation_steps

            loss.backward()
        
        # Only perform synchronization and optimizer step after accumulation_steps
        if (i + 1) % accumulation_steps == 0:
            # Synchronize gradients manually across ranks
            # (average them, which is what DDP would normally do in backward)
            for param in model.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                    param.grad /= world_size
            
            # Now update parameters
            optimizer.step()
            optimizer.zero_grad()

You could use the no_sync() context manager to skip gradient synchronization on the intermediate accumulation steps, and leave the final backward pass of each accumulation window outside of it so DDP performs its usual all-reduce there, instead of all-reducing the gradients manually.
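
For what it's worth, here is a minimal sketch of that variant, using the same placeholder names as the snippet above (YourModel, loss_fn, dataloader, device, local_rank, num_epochs) and assuming the process group is already initialized. The only DDP-specific parts are wrapping the intermediate micro-batches, forward pass included, in no_sync(), and running a plain backward() on the last micro-batch of each window so DDP does its usual bucketed all-reduce there:

import torch
import torch.distributed as dist
from contextlib import nullcontext
from torch.nn.parallel import DistributedDataParallel as DDP

# Placeholder setup, as above
model = DDP(YourModel().to(device), device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters())
accumulation_steps = 4

optimizer.zero_grad()
for epoch in range(num_epochs):
    for i, (inputs, targets) in enumerate(dataloader):
        # Sync only on the last micro-batch of each accumulation window
        is_sync_step = (i + 1) % accumulation_steps == 0
        ctx = nullcontext() if is_sync_step else model.no_sync()

        with ctx:
            outputs = model(inputs)
            loss = loss_fn(outputs, targets) / accumulation_steps
            loss.backward()  # DDP all-reduces here only on sync steps

        if is_sync_step:
            optimizer.step()
            optimizer.zero_grad()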