DDP not syncing graidents when trying to do two backward passes

Hi,

I am a bit new to DDP but I couldn’t find any answers for the query I was searching for. I am trying to do two backward passes through my network and my hope is that DDP will sync the gradients of the two backward passes but somehow this is not happening. Following is a minimalistic code of this behavior:

import sys, os, time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

def train(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    net = nn.Linear(10, 10).to(rank)
    net = DDP(net, device_ids=[rank])
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

    dataset = torch.utils.data.TensorDataset(torch.randn(100, 10), torch.randn(100, 10))
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, sampler=torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=world_size, rank=rank))
    print(f"[Rank {rank}] Dataset length: {len(dataset)}")

    for epoch in range(10):
        for i, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            out1 = net(x.to(rank))
            out2 = net(x.to(rank)**2)
            
            out1.backward(torch.rand_like(out1))
            out2.backward(torch.rand_like(out2))
            print(f"[Rank {rank}] Mlp weights grad: {list(net.parameters())[0].grad[:1]}", flush=True)
            optimizer.step()
            if i == 5:
                break
        break
    
    # ddp cleanup
    dist.destroy_process_group()


def main():
    world_size = 2
    print(f"World size: {world_size}")
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
    # command to launch torch distributed: torchrun --nproc_per_node=2 --nnodes=1 ddp_debug.py

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main() 

The output that I get by running the above script is:

World size: 2                                                                                                                                                      
World size: 2                                                                                                                                                      
[Rank 0] Dataset length: 100                                                                                                                                       
[Rank 1] Dataset length: 100                                                                                                                                       
[Rank 1] Mlp weights grad: tensor([[-0.1601,  2.1311,  0.8478,  2.0056,  0.6415,  2.2877, -0.3367,  3.1280,                                                        
         -0.4713,  2.4644]], device='cuda:1')                                                                                                                      
[Rank 0] Mlp weights grad: tensor([[0.6724, 5.8617, 0.7745, 0.4691, 0.5578, 0.5772, 3.6863, 1.4485, 1.6509,                                                        
         0.8201]], device='cuda:0')                                                                                                                                
[Rank 1] Mlp weights grad: tensor([[ 1.7129,  6.4606, 10.4901,  1.6184,  0.2531,  4.0623,  0.6212,  0.3135,                                                        
          2.0633,  3.4560]], device='cuda:1')                                                                                                                      
[Rank 0] Mlp weights grad: tensor([[0.3261, 0.5744, 2.4932, 0.1465, 0.7310, 2.1908, 0.4547, 0.0933, 3.3855,
         2.2625]], device='cuda:0')
[Rank 1] Mlp weights grad: tensor([[ 1.3302, -0.9185,  1.1055,  0.7271, -1.2709,  0.5219, -0.5395,  1.6983,
         -0.0840,  3.0954]], device='cuda:1')
[Rank 0] Mlp weights grad: tensor([[ 0.8065,  0.2661,  0.9931,  2.1985,  2.8956, -0.2228,  0.0563,  1.8179,
          0.4894,  2.0712]], device='cuda:0')
[Rank 1] Mlp weights grad: tensor([[3.3092, 2.4799, 2.0755, 1.0150, 5.7167, 3.4727, 3.1275, 1.0762, 1.9648,
         6.8734]], device='cuda:1')

Thanks for any help!