Synchronization steps in distributed data parallel

As far as I understand, the DistributedDataParallel module synchronizes gradients between nodes automatically. What I don't understand is exactly when this synchronization happens.

For example, the snippet below is taken from the "Getting Started with Distributed Data Parallel" PyTorch tutorial, with a small change:

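import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# setup() and ToyModel are defined in the tutorial and omitted here.

def demo_basic(rank, world_size):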
    setup(rank, world_size)

    # set up devices for this process; with 8 GPUs and world_size == 2,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    # create model and move it to device_ids[0]
    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])

    loss = loss_fn(outputs, labels)
    # loss_fn(outputs, labels).backward()

In the example above, is the computed loss synchronized across all nodes? That is, does the loss value represent only the local node's loss, or is it averaged over all nodes?

The DDP() wrapper takes care of all the synchronization and offers an nn.Module-like API, so you can use it transparently.
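
To make the loss question concrete: DDP synchronizes gradients, and it does so during backward(). The forward pass and the loss are computed independently in each process, so the loss you see is the local one unless you reduce it yourself. Below is a minimal pure-Python simulation of two ranks. It does not use torch or any real communication; the one-weight model, the data, and the helper names (local_loss_and_grad, all_reduce_mean) are made up for illustration.

```python
# Pure-Python simulation of two DDP ranks (no torch, no real communication).
# Model: y = w * x with a single weight w; loss: mean squared error.

def local_loss_and_grad(w, xs, ys):
    """Return the MSE loss and dLoss/dw for one rank's local batch."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad

def all_reduce_mean(values):
    """Stand-in for an averaging all-reduce: every rank receives the mean."""
    mean = sum(values) / len(values)
    return [mean for _ in values]

w = 1.0  # identical initial weight on every rank
batches = [
    ([1.0, 2.0], [2.0, 4.0]),  # rank 0's local batch (inputs, targets)
    ([3.0, 4.0], [3.0, 4.0]),  # rank 1's local batch (inputs, targets)
]

# "Forward" and "backward" run independently on each rank.
local = [local_loss_and_grad(w, xs, ys) for xs, ys in batches]
losses = [loss for loss, _ in local]
local_grads = [grad for _, grad in local]

# Only the gradients are averaged across ranks.
synced_grads = all_reduce_mean(local_grads)

print("per-rank loss:       ", losses)        # differs between ranks
print("per-rank local grad: ", local_grads)   # differs before the all-reduce
print("per-rank synced grad:", synced_grads)  # identical after the all-reduce
```

If you want a global loss value for logging in real DDP code, you would have to reduce it explicitly, for example with torch.distributed.all_reduce on a detached copy of the loss.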

Hi, do you know where the code for gradient synchronization during backward lives in the PyTorch source?

@meilu_zhu The DDP wrapper creates a c10d.Reducer which is responsible for concatenating multiple gradients into larger buckets and reducing them. You can find the source code at torch/csrc/distributed/c10d/reducer.cpp.
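
As a rough mental model of the bucketing described above, here is a sketch in plain Python rather than the actual C++. The real Reducer is considerably more involved. Everything here is hypothetical and only meant to show the idea: ToyReducer, grad_ready_hook, fake_all_reduce_mean, the parameter names, and the peer gradient values are all invented. Each parameter has a callback that fires when its gradient is ready, and a bucket is reduced with a single collective once all of its gradients have arrived.

```python
# Hypothetical, simplified model of gradient bucketing. The names below are
# illustrative only and do not mirror the real c10d Reducer API.

class ToyReducer:
    """Groups gradients into buckets and reduces a bucket once it is full."""

    def __init__(self, param_names, bucket_cap, all_reduce):
        # Split the parameter names into fixed-size buckets.
        self.buckets = [param_names[i:i + bucket_cap]
                        for i in range(0, len(param_names), bucket_cap)]
        self.bucket_of = {name: idx
                          for idx, names in enumerate(self.buckets)
                          for name in names}
        # Number of gradients each bucket is still waiting for.
        self.pending = [len(names) for names in self.buckets]
        self.grads = {}
        self.all_reduce = all_reduce
        self.log = []  # order in which buckets were reduced

    def grad_ready_hook(self, name, grad):
        """Called once per parameter when its gradient has been computed."""
        self.grads[name] = grad
        idx = self.bucket_of[name]
        self.pending[idx] -= 1
        if self.pending[idx] == 0:
            self._reduce_bucket(idx)

    def _reduce_bucket(self, idx):
        names = self.buckets[idx]
        flat = [self.grads[n] for n in names]  # gather into one flat buffer
        reduced = self.all_reduce(names, flat)  # one collective per bucket
        for n, g in zip(names, reduced):
            self.grads[n] = g
        self.log.append(idx)

# Pretend gradients held by a second rank (made-up values).
peer_grads = {"w1": 4.0, "b1": 0.0, "w2": 2.0, "b2": 6.0}

def fake_all_reduce_mean(names, flat):
    """Stand-in for an averaging all-reduce across two ranks."""
    return [(g + peer_grads[n]) / 2 for n, g in zip(names, flat)]

reducer = ToyReducer(["w1", "b1", "w2", "b2"], bucket_cap=2,
                     all_reduce=fake_all_reduce_mean)

# Backward produces gradients for the last layer first.
for name, grad in [("b2", 2.0), ("w2", 4.0), ("b1", 1.0), ("w1", 0.0)]:
    reducer.grad_ready_hook(name, grad)

print(reducer.log)    # bucket 1 is reduced before bucket 0
print(reducer.grads)  # averaged gradients
```

Grouping gradients this way means fewer, larger collectives, and a bucket can start reducing while backward is still computing gradients for earlier layers.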

Hi @pietern, thanks for your answer. DistributedDataParallel automatically averages the gradients when loss.backward() is called, but I couldn't find the code that connects the two: how does calling loss.backward() trigger torch/csrc/distributed/c10d/reducer.cpp to concatenate the gradients into buckets? Could you tell me where that is, please?