Hello! According to the docs, torch’s DistributedDataParallel.no_sync should disable gradient synchronization in the backward pass. But surprisingly, it seems complicated to make it work for multiple forward passes followed by multiple backward passes.
Here’s the example from the docs (it pairs each forward with an immediate backward):
```python
ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():
    ddp(input).backward()      # no synchronization, accumulate grads
ddp(another_input).backward()  # synchronize grads
```
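(For anyone trying to reproduce this, here’s the minimal scaffold I’m assuming around all the snippets below. The model, inputs, and process-group setup are placeholders of mine, not from the docs; launch with torchrun so the env:// rendezvous variables are set.)

```python
import torch
import torch.distributed as dist
import torch.nn as nn

dist.init_process_group(backend="gloo")  # torchrun sets MASTER_ADDR/PORT, RANK, WORLD_SIZE
pg = dist.new_group()  # a group over all ranks; passing None would use the default group

class ToyModel(nn.Module):  # placeholder model that returns a scalar "loss"
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x).sum()  # scalar output, so ddp(input).backward() works directly

model = ToyModel()
input = torch.randn(8, 10)          # placeholder batch
another_input = torch.randn(8, 10)  # placeholder batch
```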
Now let’s separate the forward pass from the backward pass:
```python
ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():  # forward under no_sync -> works
    loss = ddp(input)
loss.backward()                # no synchronization, accumulate grads
ddp(another_input).backward()  # synchronize grads
```
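I believe this works because no_sync only flips a flag that DDP consults during forward, so the “sync or not” decision is baked in at forward time. You can watch the flag directly; require_backward_grad_sync is the attribute I see on the DDP module (observed behavior, not documented API):

```python
print(ddp.require_backward_grad_sync)      # True by default
with ddp.no_sync():
    print(ddp.require_backward_grad_sync)  # False inside the context
print(ddp.require_backward_grad_sync)      # True again after exiting
```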
```python
ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
loss = ddp(input)
with ddp.no_sync():  # backward under no_sync -> doesn't work
    loss.backward()  # synchronize grads
ddp(another_input).backward()  # synchronize grads
# This means that no_sync doesn't affect backward() calls, only forward() calls
```
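Looking at torch/nn/parallel/distributed.py, no_sync appears to be nothing more than a context manager around that flag, which would explain why wrapping backward() in it has no effect. Roughly (paraphrased from the source, not verbatim):

```python
from contextlib import contextmanager

@contextmanager
def no_sync(self):
    # Save the flag, clear it for the duration of the context, restore it on exit.
    # The flag is only read during forward(), where DDP decides whether to call
    # reducer.prepare_for_backward(); backward() itself never looks at it.
    old_require_backward_grad_sync = self.require_backward_grad_sync
    self.require_backward_grad_sync = False
    try:
        yield
    finally:
        self.require_backward_grad_sync = old_require_backward_grad_sync
```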
So here’s my question: what’s the correct way to make the following work, without hacking into reducer.prepare_for_backward?
```python
ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():
    loss_1 = ddp(input)
loss_2 = ddp(another_input)
loss_1.backward()  # synchronizes grads, but I expected no sync here (how do I get that?)
loss_2.backward()  # synchronize grads
```
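For completeness: the only pattern I’ve found so far that gives the behavior I want is to interleave each forward with its backward, keeping only the last pair outside no_sync. That matches the documented usage, but it doesn’t let me group all the forwards before all the backwards, which is what I’m after:

```python
ddp = torch.nn.parallel.DistributedDataParallel(model, pg)
with ddp.no_sync():
    loss_1 = ddp(input)
    loss_1.backward()  # no synchronization, accumulate grads
loss_2 = ddp(another_input)
loss_2.backward()      # synchronizes the accumulated grads
```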