DistributedDataParallel.no_sync with multiple forwards

Hello! According to the docs, torch’s DistributedDataParallel.no_sync should (in theory) disable gradient synchronization in the backward pass. But surprisingly, it seems complicated to make it work for multiple forward passes followed by multiple backward passes.

Here’s the example from the docs (forward and backward together):

ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():
    ddp(input).backward()  # no synchronization, accumulate grads
ddp(another_input).backward()  # synchronize grads

Let’s separate the forward from the backward:

ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():  # forward under no_sync -> works
    loss = ddp(input)
loss.backward()  # no synchronization, accumulate grads
ddp(another_input).backward()  # synchronize grads

ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
loss = ddp(input)
with ddp.no_sync():  # backward under no_sync -> doesn't work
    loss.backward()  # synchronizes grads anyway
ddp(another_input).backward()  # synchronize grads
# This means that no_sync doesn't affect backward() calls, only forward() calls

So here’s my question: what’s the correct way of doing this without hacking into reducer.prepare_for_backward?

ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():
    loss_1 = ddp(input)
loss_2 = ddp(another_input)
loss_1.backward()  # synchronizes grads, but I expect no sync here (how can I avoid it?)
loss_2.backward()  # synchronize grads
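
For comparison, the closest working variant I have (based on the case above where the forward sits under no_sync) is to run the first backward before the second, synchronized forward, but that is not the forward-forward-backward-backward order I need:

ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():
    loss_1 = ddp(input)
loss_1.backward()  # no synchronization, grads accumulate
loss_2 = ddp(another_input)
loss_2.backward()  # synchronize grads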

This might help: What's no_sync() exactly do in DDP

The source code (not the help document) gives vague instructions that both the forward and the backward must be done under no_sync. The example I pasted above shows how things work correctly, and how they don't.
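
To make that concrete, here is a rough paraphrase of how no_sync() is implemented in the torch versions I have looked at (not an exact copy, check your version's source): the context manager only flips a flag, and only forward() reads it.

from contextlib import contextmanager

@contextmanager
def no_sync(self):  # method of DistributedDataParallel, paraphrased
    old_require_backward_grad_sync = self.require_backward_grad_sync
    self.require_backward_grad_sync = False  # only a flag; backward() is untouched
    try:
        yield
    finally:
        self.require_backward_grad_sync = old_require_backward_grad_sync

# forward() only calls reducer.prepare_for_backward(...), which arms the
# allreduce hooks for the next backward, when require_backward_grad_sync is
# True. backward() never reads the flag, which is why wrapping only the
# backward() in no_sync has no effect.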

This article on Hugging Face explains how no_sync works, with a forward and backward code example: Gradient Synchronization
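
As I understand that guide, the pattern boils down to keeping both the forward and the backward of every non-final micro-batch inside no_sync. A minimal sketch in plain PyTorch, assuming ddp, optimizer, loss_fn and a list micro_batches already exist (those names are mine, not from the article):

from contextlib import nullcontext

optimizer.zero_grad()
for i, (x, y) in enumerate(micro_batches):
    is_last = i == len(micro_batches) - 1
    ctx = nullcontext() if is_last else ddp.no_sync()
    with ctx:  # forward AND backward share the same context
        loss = loss_fn(ddp(x), y)
        loss.backward()  # the allreduce only fires on the last micro-batch
optimizer.step()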

Hope this helps you. Took me a while to figure it out.

Other things will come up once you get it working smoothly. For example: I use torchrun and DDP, and I have some early-stopping code. One of my 4 GPUs did the math and decided it should exit early. The other 3 GPUs sat with memory full and 100% utilization, but the program wasn't moving forward. I assume it can be solved by telling torchrun to use [1-4] GPUs instead of exactly 4; I think torchrun is saying "hey, I need to stop because I must have 4 GPUs". I'll also review the early-stopping code to see if it can be improved.
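
My reading of the hang is that the other ranks are blocked in a collective (e.g. the gradient allreduce) that the exited rank never joins, so the early-stop decision itself needs to be agreed on by all ranks. A sketch of one way I would try that, assuming torch.distributed is already initialized (this is my own sketch, not something from the docs):

import torch
import torch.distributed as dist

def everyone_should_stop(local_stop: bool, device) -> bool:
    # Each rank contributes its own decision; MAX means "stop if any rank
    # wants to stop", so all ranks leave the training loop on the same step.
    flag = torch.tensor([1.0 if local_stop else 0.0], device=device)
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() > 0

(Using ReduceOp.MIN instead would mean "stop only once every rank wants to stop".)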

I'm not exactly sure how I can prove that the models on the 4 GPUs update their weights at the end of the epoch (the backward without no_sync) and "combine or average their learning of the best weights". I'll just assume it happened. If I'm saving the best model so far as my program progresses, does it matter which GPU saves the model? What if one GPU has 5% higher accuracy than another? Should I save the highest out of all the GPUs? But if they update each other eventually, does it really matter?
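
For what it's worth, my understanding is that a synchronized backward gives every replica the same averaged gradients, so after the optimizer step the weights should be identical on all 4 GPUs and it shouldn't matter which rank saves; most examples save from rank 0 only. Per-GPU accuracy can still differ because each rank usually evaluates a different shard of the data, not because the weights diverged. A rough way I would try to verify the weights match, assuming the process group is initialized (the file name is just an example):

import torch
import torch.distributed as dist

def params_match_rank0(model) -> bool:
    ok = True
    for p in model.parameters():
        ref = p.detach().clone()
        dist.broadcast(ref, src=0)  # every rank receives rank 0's copy
        ok = ok and torch.equal(p.detach(), ref)  # compare with local weights
    return ok

if dist.get_rank() == 0:  # replicas match, so saving from one rank is enough
    torch.save(model.state_dict(), "best_model.pt")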
