Synchronization steps in distributed data parallel

As far as I understand, the DistributedDataParallel module performs gradient synchronization between different nodes automatically. One thing I don't understand clearly is when exactly this synchronization happens.

For example, the snippet below is from the PyTorch tutorial "Getting Started with Distributed Data Parallel", with a small change:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# setup(), cleanup() and ToyModel are defined earlier in the same tutorial.

def demo_basic(rank, world_size):
    setup(rank, world_size)

    # setup devices for this process; e.g. with 8 GPUs and world_size=2,
    # rank 0 uses GPUs [0, 1, 2, 3] and rank 1 uses GPUs [4, 5, 6, 7].
    n = torch.cuda.device_count() // world_size
    device_ids = list(range(rank * n, (rank + 1) * n))

    # create model and move it to device_ids[0]
    model = ToyModel().to(device_ids[0])
    # output_device defaults to device_ids[0]
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])

    loss = loss_fn(outputs, labels)
    loss.backward()
    # equivalent one-liner: loss_fn(outputs, labels).backward()

    optimizer.step()

    cleanup()

In the above example, is the computed loss synchronized among all nodes? That is, does the loss value represent only each node's local loss, or is it averaged across all nodes?

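The DDP() wrapper takes care of all the synchronization and offers an nn.Module-like API, so you can use it transparently. Note that what it synchronizes is the gradients, not the loss: each process computes loss on its own input batch, and during loss.backward() DDP all-reduces and averages the gradients across processes, so every replica applies the same update in optimizer.step().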
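To make the loss-versus-gradient distinction concrete, here is a pure-Python simulation (no PyTorch needed) of two ranks fitting y = w * x with an MSE loss. The helpers local_loss, local_grad and allreduce_mean are hypothetical stand-ins for the forward pass, autograd, and the all-reduce DDP runs during backward; the data is made up. It also checks a useful property: when every rank has the same batch size and the loss is a mean, the averaged gradient equals the gradient of the loss over the combined global batch.

```python
# Pure-Python simulation of what DDP does with losses and gradients.
# Model: y_hat = w * x, loss = mean squared error over the local batch.

def local_loss(w, xs, ys):
    # MSE on this rank's batch only; DDP never averages this value
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def local_grad(w, xs, ys):
    # d(loss)/dw for the MSE above, i.e. what loss.backward() computes locally
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def allreduce_mean(values):
    # Stand-in for DDP's all-reduce: sum over ranks, divide by world size
    avg = sum(values) / len(values)
    return [avg] * len(values)

w = 0.5
batches = {  # rank -> (inputs, targets); each rank sees a different shard
    0: ([1.0, 2.0], [2.0, 4.0]),
    1: ([3.0, 4.0], [6.0, 9.0]),
}

losses = [local_loss(w, *batches[r]) for r in sorted(batches)]
grads = allreduce_mean([local_grad(w, *batches[r]) for r in sorted(batches)])

# Gradient of the mean loss over the combined (global) batch of 4 samples
all_x = batches[0][0] + batches[1][0]
all_y = batches[0][1] + batches[1][1]
global_grad = local_grad(w, all_x, all_y)

print("per-rank losses:", losses)                  # differ from rank to rank
print("per-rank grads after allreduce:", grads)    # identical on every rank
print("global-batch grad:", global_grad)
```

In other words, printing loss in each process shows that process's local value; if you want a global loss for logging, you have to reduce it yourself.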


Hi, do you know where the gradient synchronization during backward is implemented in the PyTorch source code?

@meilu_zhu The DDP wrapper creates a c10d.Reducer which is responsible for concatenating multiple gradients into larger buckets and reducing them. You can find the source code at torch/csrc/distributed/c10d/reducer.cpp.
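As a rough illustration of "concatenating multiple gradients into larger buckets", here is a simplified pure-Python sketch of bucket assignment. It assumes, as DDP does by default, that parameters are visited in reverse registration order (roughly the order in which their gradients become ready during backward) and that a bucket is closed once it reaches a size cap; bucket_cap_mb mirrors the DDP constructor argument of that name (default 25), while assign_buckets and the parameter sizes are hypothetical. The real Reducer also groups by dtype and device, among other details, which this sketch ignores.

```python
# Simplified model of DDP-style gradient bucketing (not the real Reducer).
def assign_buckets(param_sizes_bytes, bucket_cap_mb=25):
    """Group parameter indices into buckets of roughly bucket_cap_mb MiB.

    Parameters are visited in reverse order, because the gradients of the
    last layers are produced first during backward. A bucket is closed as
    soon as its accumulated size reaches the cap.
    """
    cap = bucket_cap_mb * 1024 * 1024
    buckets, current, current_bytes = [], [], 0
    for idx in reversed(range(len(param_sizes_bytes))):
        current.append(idx)
        current_bytes += param_sizes_bytes[idx]
        if current_bytes >= cap:  # bucket is full: close it
            buckets.append(current)
            current, current_bytes = [], 0
    if current:  # leftover, partially filled bucket
        buckets.append(current)
    return buckets

MiB = 1024 * 1024
# Hypothetical model: five parameter tensors with these gradient sizes
sizes = [10 * MiB, 2 * MiB, 20 * MiB, 4 * MiB, 8 * MiB]
print(assign_buckets(sizes))
```

Fewer, larger buckets mean fewer collective calls, while smaller buckets let communication start earlier in the backward pass; the cap trades these off.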

Hi @pietern, thanks for your answer. DistributedDataParallel automatically averages the gradients when loss.backward() is called, but I couldn't find the code in the PyTorch source that shows how calling loss.backward() triggers the Reducer in torch/csrc/distributed/c10d/reducer.cpp to concatenate the gradients. Could you tell me where that happens, please?
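The link between loss.backward() and the Reducer is autograd hooks: when DDP is constructed, the Reducer registers a hook for each parameter that fires once that parameter's gradient has been computed. The hook marks the gradient as ready, and when every gradient in a bucket is ready, that bucket is all-reduced, overlapping communication with the rest of the backward pass. The plain-Python simulation below mimics only this control flow with two fake ranks; SimulatedReducer, on_grad_ready and the gradient values are hypothetical, and the real logic lives in C++ in reducer.cpp.

```python
# Plain-Python simulation of hook-driven bucket reduction (not the real Reducer).
class SimulatedReducer:
    def __init__(self, buckets, world_size):
        self.buckets = buckets                  # list of lists of parameter indices
        self.world_size = world_size
        self.pending = [set(b) for b in buckets]
        self.grads = {}                         # param index -> per-rank gradients
        self.reduced = {}                       # param index -> averaged gradient
        self.reduce_order = []                  # bucket ids in the order reduced

    def on_grad_ready(self, idx, per_rank_grads):
        # Called once per parameter, like the autograd hook DDP registers
        self.grads[idx] = per_rank_grads
        for b, waiting in enumerate(self.pending):
            if idx in waiting:
                waiting.discard(idx)
                if not waiting:                 # whole bucket ready: all-reduce it
                    self._allreduce(b)
                return

    def _allreduce(self, b):
        # Average each gradient in the bucket across the simulated ranks
        for idx in self.buckets[b]:
            self.reduced[idx] = sum(self.grads[idx]) / self.world_size
        self.reduce_order.append(b)

reducer = SimulatedReducer(buckets=[[3, 2], [1, 0]], world_size=2)
# Backward produces gradients from the last layer to the first
fake_grads = {3: [1.0, 3.0], 2: [2.0, 2.0], 1: [0.0, 6.0], 0: [-1.0, 1.0]}
for idx in [3, 2, 1, 0]:
    reducer.on_grad_ready(idx, fake_grads[idx])
    print(f"grad {idx} ready, buckets reduced so far: {reducer.reduce_order}")

print(reducer.reduced)
```

Bucket 0 is reduced while gradients 1 and 0 are still being computed, which is the overlap that makes DDP's backward pass efficient.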

In https://pytorch.org/tutorials/intermediate/ddp_tutorial.html, the demo_checkpoint code really confused me. I have two main questions:
(1) In demo_checkpoint, all processes are synchronized by loading the same checkpoint saved by process 0. Is this necessary when we train a model with DistributedDataParallel?
(2) If it is necessary to synchronize the model across multiple nodes and GPUs by loading the same checkpoint, how can node B load a checkpoint saved on node A?
I don't know if I understand demo_checkpoint correctly. Could you please help answer these questions? Thanks!

@hhxx

(1) In demo_checkpoint, all processes are synchronized by loading the same checkpoint saved by process 0. Is this necessary when we train a model with DistributedDataParallel?

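No, this is not necessary. DDP already broadcasts the model parameters and buffers from rank 0 to all other ranks when the DDP wrapper is constructed, so all replicas start from the same state. Checkpointing is only useful when your training job takes a very long time and might crash in the middle; you can then use the checkpoint to recover instead of starting over from scratch.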

(2) If it is necessary to synchronize the model across multiple nodes and GPUs by loading the same checkpoint, how can node B load a checkpoint saved on node A?
I don't know if I understand demo_checkpoint correctly. Could you please help answer these questions? Thanks!

The recovery scheme should be application-dependent. That tutorial demonstrates single-machine multi-GPU DDP with checkpointing, so all DDP processes can read from the same file. If you need checkpointing and your training spans multiple machines, you can load the checkpoint on rank 0 and then broadcast it to the other ranks using torch.distributed.broadcast.
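The "load on rank 0, then broadcast" scheme can be sketched without any GPUs by simulating ranks as plain dictionaries. Only rank 0 reads the checkpoint file (imagine it exists only on node A's disk). The broadcast helper below is a hypothetical stand-in for torch.distributed.broadcast; in real code that call is collective, so every rank calls it with a tensor of the same shape and src=0, typically once per tensor in the state dict. The file name and parameter values are made up for illustration.

```python
import json
import os
import tempfile

def broadcast(states, src=0):
    # Stand-in for torch.distributed.broadcast: every rank ends up with src's data
    for rank in range(len(states)):
        if rank != src:
            states[rank] = {k: list(v) for k, v in states[src].items()}

world_size = 3
with tempfile.TemporaryDirectory() as ckpt_dir:    # pretend this is node A's disk
    ckpt_path = os.path.join(ckpt_dir, "checkpoint.json")  # hypothetical file name
    with open(ckpt_path, "w") as f:
        json.dump({"weight": [0.1, 0.2], "bias": [0.3]}, f)

    # Every rank starts with its own (different) freshly initialized parameters
    states = [{"weight": [float(r), float(r)], "bias": [float(r)]}
              for r in range(world_size)]

    # Only rank 0 reads the file ...
    with open(ckpt_path) as f:
        states[0] = json.load(f)

# ... and then shares the contents with the other ranks
broadcast(states, src=0)
print(states)
```

The point of the design is that node B never needs file-system access to node A; the data travels over the same process group DDP already uses.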

BTW, we should probably call that tutorial "intermediate" and use this one as a starting example.

Thanks very much! It is very clear.
Another question: when we train a model with DistributedDataParallel, it is suggested to use DistributedSampler together with it. Do we need to call sampler.set_epoch(epoch) at the start of each epoch, as described in the set_epoch documentation for DistributedSampler?

cc @vincentqb on the DistributedSampler question 🙂
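For context on the set_epoch question: with shuffle=True, DistributedSampler seeds its shuffle with seed + epoch, pads the index list so it divides evenly, and gives each rank every num_replicas-th index starting at its own rank. The epoch defaults to 0, so if set_epoch is never called every epoch replays the same order. The pure-Python sketch below mimics that logic; shard_indices is a hypothetical name, and it uses random.Random in place of the torch.Generator and torch.randperm the real sampler uses, so its permutations will not match PyTorch's.

```python
import math
import random

def shard_indices(dataset_len, num_replicas, rank, epoch, seed=0, shuffle=True):
    """Mimic DistributedSampler's per-rank index assignment (simplified sketch)."""
    indices = list(range(dataset_len))
    if shuffle:
        # The shuffle is seeded with seed + epoch, so the order only changes
        # between epochs if the epoch number is actually updated.
        random.Random(seed + epoch).shuffle(indices)
    # Pad by repeating indices so every rank gets the same number of samples
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Rank r takes positions r, r + num_replicas, r + 2 * num_replicas, ...
    return indices[rank:total_size:num_replicas]

world = 2
for epoch in range(2):  # in real code, sampler.set_epoch(epoch) supplies this value
    shards = [shard_indices(10, world, r, epoch) for r in range(world)]
    print(f"epoch {epoch}: rank 0 -> {shards[0]}, rank 1 -> {shards[1]}")
```

In a training loop this corresponds to calling sampler.set_epoch(epoch) at the start of each epoch, before iterating over the DataLoader.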