Gradients across different ranks are not synchronized when using DDP

🐛 Bug

Gradients across different ranks are not synchronized when using DDP.
The following script reproduces the problem:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as functional
from torch.nn.parallel import DistributedDataParallel as DDP

import os
import numpy as np
import random

class TwoLayerMLP(nn.Module):
    def __init__(self, model_dim, feedford_dim):
        super(TwoLayerMLP, self).__init__()
        self.linear1 = nn.Linear(model_dim, feedford_dim, bias=False)
        self.linear2 = nn.Linear(feedford_dim, model_dim, bias=False)
    
    def forward(self, input):
        a1 = functional.relu(self.linear1(input))
        a2 = self.linear2(a1)
        return input + a2
    
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # torch.backends.cudnn.deterministic = True

def run_fn(rank, world_size):
    print(f"Running DDP on rank {rank}.")
    # create default process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    
    setup_seed(1024)
    model_dim = 2
    feedford_dim = 8
    micro_batch_size = 3
    inputs_1 = torch.Tensor([[ 0.5920, -0.6301],
        [-0.8856,  1.2261],
        [-0.4671, -1.0279]]).to(rank)
    labels_1 = torch.Tensor([[-1.0387,  0.1039],
        [ 0.5989, -1.4801],
        [-0.8618, -0.9181]]).to(rank)
    inputs_2 = torch.Tensor([[-0.0355,  0.4145],
        [ 0.6798, -0.2936],
        [ 0.1872, -0.2724]]).to(rank)
    labels_2 = torch.Tensor([[-0.5524, -0.8358],
        [-2.8240,  0.2564],
        [ 0.5045, -1.1290]]).to(rank)
    inputs_3 = torch.Tensor([[-0.6166, -0.3604],
        [ 0.1046,  1.4810],
        [-0.2449,  1.1106]]).to(rank)
    labels_3 = torch.Tensor([[-0.3063, -1.3320],
        [ 0.7281,  0.1859],
        [ 0.5624, -1.4094]]).to(rank)
    inputs_1 = inputs_1 + rank
    labels_1 = labels_1 + rank
    inputs_2 = inputs_2 + rank
    labels_2 = labels_2 + rank
    inputs_3 = inputs_3 + rank
    labels_3 = labels_3 + rank
    # inputs_1.requires_grad_(True)
    # inputs_2.requires_grad_(True)
    # inputs_3.requires_grad_(True)

    inputs = [inputs_1, inputs_2, inputs_3]
    labels = [labels_1, labels_2, labels_3]

    loss_fn = nn.MSELoss()
    model = TwoLayerMLP(model_dim, feedford_dim).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    outputs = [0] * micro_batch_size
    outputs[0] = ddp_model(inputs[0])
    with ddp_model.no_sync():
        for i in range(1, micro_batch_size):
            outputs[i] = ddp_model(inputs[i])
        for i in range(1, micro_batch_size):
            loss = loss_fn(outputs[i], labels[i])
            loss.backward()
    loss = loss_fn(outputs[0], labels[0])
    loss.backward()

    print("rank", rank, "backward:")
    for name, param in ddp_model.named_parameters():
        print("rank", rank, name, ":", param.grad)

    optimizer.step()

def main():
    world_size = 2
    mp.spawn(run_fn,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    # Environment variables which need to be
    # set when using c10d's default "env"
    # initialization mode.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    main()
    print("Done!")

The output is:

Running DDP on rank 1.
Running DDP on rank 0.
rank 1 backward:
rank 0 backward:
rank 1 module.linear1.weight : tensor([[-0.1659,  1.0537],
        [-0.4481, -1.6936],
        [ 0.2067, -1.4908],
        [ 0.0000,  0.0000],
        [ 0.1104, -0.0408],
        [-0.4700, -0.1127],
        [-0.1041, -0.0192],
        [-0.0166, -0.3176]], device='cuda:1')
rank 1 module.linear2.weight : tensor([[ 0.3232, -0.8505,  0.4239,  0.0000,  0.0077,  0.6049,  0.9509, -0.0285],
        [ 2.0952,  2.1533,  1.7861,  0.0000,  0.0142, -0.1346, -0.2149,  0.0685]],
       device='cuda:1')
rank 0 module.linear1.weight : tensor([[-0.4369,  0.6137],
        [ 0.3596, -0.9361],
        [ 0.7006, -0.9808],
        [-0.2947,  0.4298],
        [ 0.2228, -0.1902],
        [-0.2892,  0.0564],
        [-0.1001, -0.0041],
        [ 0.1462, -0.3997]], device='cuda:0')
rank 0 module.linear2.weight : tensor([[ 0.5414, -0.6772,  0.6198, -0.0058,  0.2271,  0.8106,  0.9740, -0.2846],
        [ 0.7560,  1.8476,  0.5504,  0.0230,  0.0383, -0.1884, -0.1909,  0.7888]],
       device='cuda:0')
Done!

The gradients printed by rank 0 and rank 1 differ, so they were not synchronized across ranks.
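Comparing the printed tensors by eye is error-prone, so a programmatic check may help. The helper below is a minimal sketch: `grads_match_across_ranks` and its `atol` parameter are hypothetical names, not part of PyTorch, and it assumes the default process group is already initialized and that every rank calls it at the same point.

```python
import torch
import torch.distributed as dist


def grads_match_across_ranks(module, atol=1e-6):
    """Return True if every parameter gradient is equal on all ranks.

    Hypothetical helper; it must be called collectively by all ranks.
    """
    world_size = dist.get_world_size()
    for _, param in module.named_parameters():
        if param.grad is None:
            continue
        local = param.grad.detach().contiguous()
        # Collect this parameter's gradient from every rank.
        gathered = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(gathered, local)
        # Compare every rank's copy against rank 0's copy.
        for other in gathered[1:]:
            if not torch.allclose(gathered[0], other, atol=atol):
                return False
    return True
```

In the script above, calling `grads_match_across_ranks(ddp_model)` right after the final `loss.backward()` would report the mismatch directly.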

Versions

PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.25.0
Libc version: glibc-2.31

Python version: 3.8.10 (default, Mar 13 2023, 10:26:41) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.11.0-41-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3060 Ti
GPU 1: NVIDIA GeForce RTX 3060 Ti

Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.8.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.8.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 8
On-line CPU(s) list: 0-7
Thread(s) per core: 2
Core(s) per socket: 1
Socket(s): 4
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel Xeon Processor (Skylake, IBRS)
Stepping: 4
CPU MHz: 2399.971
BogoMIPS: 4799.94
Virtualization: VT-x
L1d cache: 128 KiB
L1i cache: 128 KiB
L2 cache: 16 MiB
L3 cache: 64 MiB
NUMA node0 CPU(s): 0-3
NUMA node1 CPU(s): 4-7
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Full generic retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology cpuid pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat umip pku ospke avx512_vnni md_clear

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.0.1+cu118
[pip3] triton==2.0.0
[conda] Could not collect

You are not using checkpointing in the posted code snippet, so checkpointing should be unrelated.

The mismatch is most likely caused by not performing a full forward/backward pass outside the no_sync context, as mentioned in the docs:

Within this context, gradients will be accumulated on module variables, which will later be synchronized in the first forward-backward pass exiting the context.

Moving outputs[0] = ddp_model(inputs[0]) down, after the no_sync block, gives the same gradients on both ranks.
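The reordering described here can be sketched roughly as follows. The function name `accumulate_then_sync` is hypothetical; the body only rearranges the calls from the posted script so that the forward for micro-batch 0 runs after the no_sync block.

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate_then_sync(ddp_model, loss_fn, inputs, labels):
    """Accumulate micro-batches 1..n-1 locally, then sync on micro-batch 0.

    Hypothetical helper that mirrors the posted script with the first
    forward moved below the no_sync block.
    """
    # Forward and backward both run inside no_sync(), so these gradients
    # are only accumulated into param.grad on the local rank.
    with ddp_model.no_sync():
        for x, y in zip(inputs[1:], labels[1:]):
            loss_fn(ddp_model(x), y).backward()
    # Both the forward and the backward for micro-batch 0 now happen after
    # the context exits; per the docs, this pass synchronizes the gradients.
    loss_fn(ddp_model(inputs[0]), labels[0]).backward()
```

In `run_fn`, this would replace the block from `outputs = [0] * micro_batch_size` through the final `loss.backward()`, called as `accumulate_then_sync(ddp_model, loss_fn, inputs, labels)`.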

Thanks, I wrote that wrong. Checkpointing is unrelated.
It works fine if I move outputs[0] = ddp_model(inputs[0]) down. But I want to know how to synchronize the gradients without moving outputs[0] = ddp_model(inputs[0]) down.

I believe the forward-backward pass after exiting the context is needed as described in the docs, but @kwen2501 can correct me.
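If the forward order really cannot change, one possible workaround is to skip DDP's automatic synchronization for the step and average the gradients by hand. This is an assumption, not something confirmed in this thread or stated in the DDP docs. The sketch below assumes every forward and backward of the step runs inside no_sync(), so the forward for outputs[0] can stay first as long as it is also inside the context. `average_gradients` is a hypothetical helper name.

```python
import torch
import torch.distributed as dist


def average_gradients(module):
    """Average param.grad across all ranks with an explicit all-reduce.

    Hypothetical helper; assumes gradients were accumulated locally
    (for example under no_sync()) and that all ranks call it together.
    """
    world_size = dist.get_world_size()
    for param in module.parameters():
        if param.grad is None:
            continue
        # Sum the local gradients over all ranks, in place.
        dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
        # Divide by the number of ranks to turn the sum into a mean.
        param.grad.div_(world_size)
```

It would be called as `average_gradients(ddp_model)` after the last `loss.backward()` and before `optimizer.step()`. This trades DDP's overlap of communication with the backward pass for one blocking all-reduce per parameter, so it is likely slower on larger models.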