Averaging Gradients in DistributedDataParallel

I am a bit confused about averaging gradients in distributed data-parallel training. There seem to be two different examples in the PyTorch documentation. In one, you create the model, move it to the available GPU, and then write a separate function to average the gradients across processes.

    import torch.distributed as dist

    def average_gradients(model):
        """ Gradient averaging: all-reduce (SUM) then divide by world size. """
        size = float(dist.get_world_size())
        for param in model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
                param.grad.data /= size

and it is executed as follows:

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = loss_function(outputs, labels)
    loss.backward()           # local gradients
    average_gradients(model)  # all-reduce: global (averaged) gradients
    optimizer.step()

The other approach I have seen doesn't define a separate function and instead wraps the model in DDP.

    model = ToyModel().cuda(device_ids[0])
    ddp_model = DDP(model, device_ids=device_ids)

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_ids[0])
    loss_fn(outputs, labels).backward()  # DDP averages gradients during backward
    optimizer.step()

I would like to know what the difference is between the two approaches and which one to use for distributed training on an HPC cluster. I specifically want to use two nodes, each with 4 GPUs.

Hey @ankahira, usually, there are 4 steps in distributed data parallel training:

  1. local forward to compute loss
  2. local backward to compute local gradients
  3. allreduce (communication) to compute global gradients, i.e., an allreduce with SUM followed by a divide by the world size to get the average
  4. optimizer step to use the global gradients to update parameters (the sketch after this list maps the steps to code)
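
To make the mapping concrete, here is a minimal sketch of the same four steps with DDP, using ddp_model, loss_fn, and optimizer as in the second example above (inputs and labels are placeholders); note that step 3 happens inside backward():

    optimizer.zero_grad()
    outputs = ddp_model(inputs)      # 1. local forward
    loss = loss_fn(outputs, labels)
    loss.backward()                  # 2. local backward + 3. allreduce (DDP overlaps the two)
    optimizer.step()                 # 4. update parameters with the averaged (global) gradients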

Both examples you mentioned above perform the same four steps and are mathematically equivalent. The difference is that DDP lets step 2 (backward computation) and step 3 (allreduce communication) overlap, so DDP is expected to be faster than the average_gradients approach.

More specifically, in the first example with average_gradients there is a hard barrier between backward and allreduce, i.e., no communication can start before the computation finishes. In the second example, DDP organizes gradients into buckets and launches communication as soon as a bucket of gradients is ready, so computation and communication run in parallel. The DDP design note in the PyTorch docs describes this bucketing in more detail.
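
If you want to see what that overlap looks like without DDP's C++ reducer, here is a simplified sketch (not DDP's actual implementation, and the helper name attach_overlap_hooks is made up) that launches an asynchronous all-reduce for each parameter as soon as its gradient is ready; it assumes PyTorch >= 2.1 for register_post_accumulate_grad_hook:

    import torch.distributed as dist

    def attach_overlap_hooks(model):
        """Launch an async all-reduce per parameter as its gradient becomes ready,
        then wait on everything before the optimizer step (rough DDP-like overlap)."""
        world_size = float(dist.get_world_size())
        handles = []

        def hook(param):
            # param.grad is fully accumulated here; reduce it in place without
            # blocking the rest of the backward pass.
            handles.append(dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, async_op=True))

        for p in model.parameters():
            if p.requires_grad:
                p.register_post_accumulate_grad_hook(hook)

        def finalize():
            # Call after loss.backward() and before optimizer.step().
            for h in handles:
                h.wait()
            for p in model.parameters():
                if p.grad is not None:
                    p.grad /= world_size
            handles.clear()

        return finalize

Usage would be roughly: call finalize = attach_overlap_hooks(model) once after building the model, then in the training loop run loss.backward(), finalize(), optimizer.step(). Real DDP does the same thing more efficiently by flattening gradients into buckets before communicating.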

I would recommend DDP. :slight_smile:
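
For the two-nodes-with-4-GPUs case, the usual pattern is one process per GPU. Here is a hedged sketch of the per-process setup (assuming you launch with torchrun, which sets RANK, LOCAL_RANK, WORLD_SIZE, and the rendezvous variables for you; the nn.Linear model is just a stand-in for your real one):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def main():
        # torchrun sets LOCAL_RANK (the GPU index on this node) for every process it spawns.
        local_rank = int(os.environ["LOCAL_RANK"])

        # NCCL is the usual backend for multi-GPU / multi-node training.
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(local_rank)

        model = torch.nn.Linear(10, 5).cuda(local_rank)  # stand-in for your real model
        ddp_model = DDP(model, device_ids=[local_rank])

        # ... build the optimizer, a DistributedSampler-based DataLoader, and the training loop ...

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

On each node you would then run something like torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d --rdzv_endpoint=<host>:<port> your_script.py (the exact launch command depends on your cluster's scheduler).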


Why would we need to divide the gradients by world size in step 3?

Because in many cases the loss function computes the mean per-sample loss (e.g., MSELoss with its default reduction='mean') rather than a summed loss (reduction='sum'). So DDP computes the average across processes, which keeps the gradient scale consistent with local (single-process) training.
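
As a quick single-process sanity check (no distributed setup needed; the linear model and random data here are made up), averaging the gradients of two half-batches matches the gradient of the mean loss over the full batch, which is exactly what the allreduce-plus-divide reproduces across workers:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    x = torch.randn(8, 3)
    y = torch.randn(8, 1)

    def grad_for(batch_x, batch_y):
        # Fresh model with identical weights on every call, so gradients are comparable.
        model = nn.Linear(3, 1)
        with torch.no_grad():
            model.weight.fill_(0.1)
            model.bias.zero_()
        loss = nn.MSELoss()(model(batch_x), batch_y)  # default reduction='mean'
        loss.backward()
        return model.weight.grad.clone()

    g_full = grad_for(x, y)                                        # full batch on one "worker"
    g_avg = (grad_for(x[:4], y[:4]) + grad_for(x[4:], y[4:])) / 2  # average of two "workers"
    print(torch.allclose(g_full, g_avg))                           # True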