DDP with gradient accumulation and clip_grad_norm

Hello,

I am trying to do gradient accumulation

model.zero_grad()                                   # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradients tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulated

as mentioned here

But my model is on multiple GPUs

So when the loss comes out of the model, how do I scale it? Also, I want to clip gradient norms during training. How should I modify my training loop?

My code is below. Does this look correct?

for i, batch in enumerate(train_loader_with_distributed_sampler):

    for param_group in optimizer.param_groups:
        param_group['lr'] = learning_rate

    x, y = batch
    y_hat = model(x)
    loss = criterion(y_hat, y).mean()

    loss = loss / hparams.gradient_accumulation_step

    reduced_loss = reduce_tensor(loss.data, n_gpus).item()

    loss.backward()

    current_accumulation_run = i % hparams.gradient_accumulation_step + 1

    # My grad clip threshold is 1.0, so it gets multiplied by 1, 2, 3, ... depending on
    # how many backward passes have been accumulated so far.
    # Or maybe I don't need to manage this?
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), hparams.grad_clip_thresh * current_accumulation_run)

    if (i + 1) % hparams.gradient_accumulation_step == 0:
        optimizer.step()
        model.zero_grad()

    if rank == 0:
        print("Optimizing Step")
        print("Train loss {} {:.6f} Grad Norm {:.6f}".format(
            i, reduced_loss, grad_norm))

You mean your gradients are on multiple GPUs, so that you cannot directly pass them to the same operator? @ptrblck @albanD does this mean the gradients need to be moved to the same device first before computing the scaling ratio? Or is the recommended way to do per-device/model-shard scaling?

I think I misunderstood your setup in the comment above. Would I be correct to assume that each model replica is on one GPU, and that you have DDP on top for gradient averaging?

If you have a single-GPU model replica + DDP, would it be acceptable to let DDP do the gradient averaging first, and then do the gradient scaling/clipping independently on every process before calling optimizer.step()? Since DDP makes sure that all model replicas have the same gradients, they should all reach the same scaling/clipping result.

Another thing: to accumulate gradients over multiple iterations, you can try using ddp.no_sync(), which helps avoid unnecessary communication overhead.
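For example, roughly like this (just a sketch; ddp_model, data_iter, criterion, and accumulation_steps are placeholder names, not from your code):

# Sketch: skip the gradient all-reduce on intermediate micro-batches and let the
# last backward() before the optimizer step trigger the synchronization.
for micro_step in range(accumulation_steps):
    inputs, labels = next(data_iter)
    loss = criterion(ddp_model(inputs), labels) / accumulation_steps
    if micro_step < accumulation_steps - 1:
        with ddp_model.no_sync():      # no all-reduce for this backward pass
            loss.backward()            # gradients accumulate locally
    else:
        loss.backward()                # this backward all-reduces the accumulated grads
optimizer.step()
ddp_model.zero_grad()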

I think this is the default behaviour of DDP, no? It is initialised with torch.distributed.init_process_group and launched from a separate script where I set the world size and everything else, as in the basic DDP tutorial. So let's assume I have 4 GPUs: my model is replicated on each GPU, and the distributed sampler provides each GPU with an equal number of batches.

Yep, it’s the default behavior. Just wanna make sure this is how DDP is used in your application so that my comment can be relevant.

Could you please elaborate more on your requirements for gradient clipping? Do you need to do that before DDP gradient averaging comm or after?

If you need to clip before the DDP comm, then you can probably use the DDP communication hook. More specifically, you can wrap the per-bucket gradient clipping together with the allreduce communication inside the hook.
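Very roughly, the hook could look something like this (a sketch only, assuming per-bucket clipping is acceptable for you; note that clipping each bucket separately is not the same as clipping the global norm over all parameters, and the GradBucket API may differ between PyTorch versions):

import torch
import torch.distributed as dist

def clip_bucket_then_allreduce(state, bucket):
    # Sketch: clip each gradient bucket locally, then all-reduce it.
    grad = bucket.buffer()                    # flattened gradients of this bucket
    max_norm = 1.0                            # assumed clipping threshold
    norm = grad.norm()
    if norm > max_norm:
        grad.mul_(max_norm / (norm + 1e-6))
    grad.div_(dist.get_world_size())          # pre-divide so the sum becomes an average
    fut = dist.all_reduce(grad, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0])

# Register on the DDP-wrapped model:
# ddp_model.register_comm_hook(state=None, hook=clip_bucket_then_allreduce)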

If it is OK to do the clipping after the DDP comm, then you can run the clipping ops after DDP's backward() and before the optimizer step.
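That is, roughly (using the names from your loop above):

loss.backward()   # DDP averages gradients across ranks during backward
if (i + 1) % hparams.gradient_accumulation_step == 0:
    # Gradients are now identical on every rank, so each process can
    # clip and step independently and still stay in sync.
    grad_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), hparams.grad_clip_thresh)
    optimizer.step()
    model.zero_grad()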


I have no specific requirements, actually. I just want to prevent gradient explosion in my implementation, so I was clipping the norm. But then I switched to the DDP module and started accumulating gradients as well, and I was no longer sure where to put the clipping.