If I understood correctly, the DDP wrapper around my model synchronizes the gradients at every backward pass. This seems to cause a massive performance bottleneck due to communication between GPUs at every mini-batch.
Because I am accumulating gradients over multiple mini-batches, I would like the gradient synchronization to happen only at every optimizer step instead of at every mini-batch backward pass, which would massively speed up my training and be theoretically equivalent.
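For context, this is roughly what my current loop looks like (a minimal sketch; YourModel, device, local_rank, dataloader, loss_fn and accumulation_steps are placeholders):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Plain DDP + gradient accumulation: DDP all-reduces gradients on every
# loss.backward(), i.e. accumulation_steps times per optimizer step.
model = DDP(YourModel().to(device), device_ids=[local_rank])
optimizer = torch.optim.Adam(model.parameters())

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    loss = loss_fn(model(inputs), targets) / accumulation_steps
    loss.backward()  # inter-GPU gradient sync happens here, on every mini-batch
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()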
Questions are:
- Is my understanding of DDP correct?
- If so, do you have any hint on how to make the gradient sync happen only on the last backward pass before the optimizer step?
Thanks in advance.
EDIT: Is the solution below, suggested by Claude, the best way to go about it?
# Setup
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group(...) has already been called and that
# device / local_rank are set for this process.
model = YourModel().to(device)
model = DDP(model, device_ids=[local_rank], find_unused_parameters=False)
optimizer = torch.optim.Adam(model.parameters())
world_size = dist.get_world_size()

# Hyperparameters
accumulation_steps = 4  # Number of batches to accumulate before synchronizing

optimizer.zero_grad()
for epoch in range(num_epochs):
    for i, (inputs, targets) in enumerate(dataloader):
        # no_sync() is the key part - it disables DDP's gradient all-reduce.
        # The forward pass must also run inside this context, because DDP
        # decides at forward time whether the next backward will synchronize.
        with model.no_sync():
            # Forward pass
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            # Scale loss by accumulation steps
            loss = loss / accumulation_steps
            # Backward pass - DDP will NOT sync at this point
            loss.backward()

        # Only synchronize and step the optimizer every accumulation_steps batches
        if (i + 1) % accumulation_steps == 0:
            # Every backward above ran under no_sync(), so all-reduce the
            # accumulated gradients manually before stepping.
            for param in model.parameters():
                if param.requires_grad and param.grad is not None:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
                    param.grad /= world_size
            # Now update parameters
            optimizer.step()
            optimizer.zero_grad()
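Or would it be equivalent (and simpler) to keep no_sync() only for the intermediate mini-batches, so that DDP's own all-reduce fires once on the last backward of each accumulation window? Rough sketch, same placeholder names as above:

optimizer.zero_grad()
for i, (inputs, targets) in enumerate(dataloader):
    sync_now = (i + 1) % accumulation_steps == 0
    if sync_now:
        # Normal forward/backward: DDP all-reduces the accumulated gradients here
        loss = loss_fn(model(inputs), targets) / accumulation_steps
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    else:
        # Intermediate mini-batch: accumulate gradients locally, no communication
        with model.no_sync():
            loss = loss_fn(model(inputs), targets) / accumulation_steps
            loss.backward()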