DDP and Gradient Sync

I am training a model with DDP and found that my gradients were not being synced after the loss.backward() call. After debugging, I found the cause: I was manually setting requires_grad = False on parameters I did not want updated. This somehow prevents synchronization of the param.grad fields of the parameters I do care about. Is this the intended behavior, and should a warning be thrown? Otherwise training proceeds with no errors.

Interestingly, when I turned on the NCCL logs, I still saw allreduce operations being logged. My hypothesis is that these are the per-bucket allreduce operations that run before the reduced results are written back into param.grad.
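
For context, I turned the NCCL logs on with the standard NCCL_DEBUG variable, i.e.

NCCL_DEBUG=INFO torchrun --standalone --nproc-per-node 2 script.py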

A small toy example to showcase the issue is provided below.

import os

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_setup():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)


class SimpleModel(nn.Module):
    def __init__(self, n_layers):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(2, 2) for _ in range(n_layers)])

    def forward(self, x):
        # Return every layer's output, stacked along a new leading dim
        output = []
        for layer in self.layers:
            x = layer(x)
            output.append(x)
        return torch.stack(output, dim=0)


def main():
    ddp_setup()
    rank = dist.get_rank()

    torch.manual_seed(420)

    # Create a two-layer linear model and wrap it with DDP
    model = SimpleModel(2)
    model = DDP(model.to(rank), device_ids=[rank], output_device=rank)

    # Turn off gradients for the second layer because I only need the first layer's output
    for name, param in model.named_parameters():
        if "1" in name:
            param.requires_grad = False

    # Pass dummy data through and compute the loss from the first layer's output
    inp = torch.rand(2) + rank
    out = model(inp)
    loss = out[0].sum()
    loss.backward()

    # Log the (unsynced) gradients on each rank
    for name, param in model.named_parameters():
        if "bias" not in name:
            print(f"RANK: {rank}, Param: {name}, Grad: {param.grad}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

which I run with

torchrun --standalone --nproc-per-node 2 script.py

with output

RANK: 1, Param: module.layers.0.weight, Grad: tensor([[1.2448, 1.8644],[1.2448, 1.8644]], device='cuda:1')
RANK: 1, Param: module.layers.1.weight, Grad: None
RANK: 0, Param: module.layers.0.weight, Grad: tensor([[0.2448, 0.8644], [0.2448, 0.8644]], device='cuda:0') 
RANK: 0, Param: module.layers.1.weight, Grad: None

If I do not set requires_grad to False, the gradients are synced as expected.
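
Freezing the layers before the DDP wrap (so that, if I understand correctly, DDP never registers them with the reducer in the first place) does seem to avoid the issue in the toy example. A minimal sketch:

# Sketch: freeze the second layer *before* wrapping, so DDP never expects a gradient for it
model = SimpleModel(2).to(rank)
for name, param in model.named_parameters():
    if "1" in name:
        param.requires_grad = False
model = DDP(model, device_ids=[rank], output_device=rank)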

Run a second iteration and the code will fail with the expected error:

[rank1]: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by 
[rank1]: making sure all `forward` function outputs participate in calculating loss. 
[rank1]: If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
[rank1]: Parameter indices which did not receive grad for rank 1: 2 3
[rank1]:  In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
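
For completeness, find_unused_parameters is a regular DDP constructor argument; in the toy example it could be passed like this, at the cost of an extra traversal of the autograd graph each iteration:

model = DDP(
    model.to(rank),
    device_ids=[rank],
    output_device=rank,
    find_unused_parameters=True,  # reducer marks params that received no grad as ready
)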

Interestingly, in my original code no error was thrown even when training for multiple iterations. I’ll see if I can replicate that in a toy example.

In my original implementation, the learnable adapter was not active when I wrapped the model with DDP, so DDP did not register the correct set of parameters to track during training, which presumably also explains why no error was thrown over multiple iterations there. After fixing that, I was only testing one batch at a time to check whether the gradients synced, so I never ran into the multi-iteration error either. Thanks for the help!
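
In case it helps anyone else, a rough sketch of the ordering that mattered in my case (MyModel and enable_adapter() are hypothetical stand-ins for my actual setup):

# Hypothetical sketch: make sure every parameter that should train exists and
# requires grad *before* the DDP wrap, since DDP picks up its parameter set at construction.
model = MyModel()            # hypothetical base model with a pluggable adapter
model.enable_adapter()       # hypothetical helper that attaches/activates the learnable adapter
model = DDP(model.to(rank), device_ids=[rank], output_device=rank)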