DistributedDataParallel: model weights and grads not synchronized with multiple forward/backward passes

To replicate, change only def demo_basic(rank, world_size) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html to the following:

def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1)
    
    optimizer.zero_grad()
    
    # run three forward passes before doing any backward pass
    outputs = {}
    outputs['0'] = ddp_model(torch.rand(20, 10))
    outputs['1'] = ddp_model(torch.rand(20, 10))
    outputs['2'] = ddp_model(torch.rand(20, 10))

    labels = torch.rand(20, 5).to(rank)

    for i in range(3):
        print(f"before {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}")
                 
        # retain the graph so the remaining outputs can still be backpropagated
        if i < 2:
            loss_fn(outputs[str(i)], labels).backward(retain_graph=True)
        else:
            loss_fn(outputs[str(i)], labels).backward()

        print(f"after {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")
                
    optimizer.step()
    print(f"last, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")
    
    cleanup()

and the output is:

before 0, rank: 0, weight: 0.1450435221195221
before 0, rank: 3, weight: 0.1450435221195221
before 0, rank: 1, weight: 0.1450435221195221
before 0, rank: 2, weight: 0.1450435221195221
after 0, rank: 0, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 0, weight: 0.1450435221195221
after 0, rank: 3, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 1, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 2, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 1, weight: 0.1450435221195221
before 1, rank: 3, weight: 0.1450435221195221
before 1, rank: 2, weight: 0.1450435221195221
after 1, rank: 0, weight: 0.1450435221195221, grad: -0.03955963999032974
after 1, rank: 3, weight: 0.1450435221195221, grad: -0.03072114661335945
before 2, rank: 0, weight: 0.1450435221195221
before 2, rank: 3, weight: 0.1450435221195221
after 1, rank: 1, weight: 0.1450435221195221, grad: -0.03775426745414734
before 2, rank: 1, weight: 0.1450435221195221
after 1, rank: 2, weight: 0.1450435221195221, grad: -0.03235533833503723
before 2, rank: 2, weight: 0.1450435221195221
after 2, rank: 0, weight: 0.1450435221195221, grad: -0.06408560276031494
after 2, rank: 3, weight: 0.1450435221195221, grad: -0.04222358390688896
after 2, rank: 1, weight: 0.1450435221195221, grad: -0.056242190301418304
last, rank: 0, weight: 0.20912912487983704, grad: -0.06408560276031494
last, rank: 3, weight: 0.18726710975170135, grad: -0.04222358390688896
last, rank: 1, weight: 0.201285719871521, grad: -0.056242190301418304
after 2, rank: 2, weight: 0.1450435221195221, grad: -0.04413666948676109
last, rank: 2, weight: 0.1891801953315735, grad: -0.04413666948676109

Weights and grads do not seem to be synchronized.
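For reference, a quick way to confirm the divergence programmatically (a sketch, assuming the process group has already been initialized by setup() as in the tutorial; check_weight_sync is just an illustrative helper name) is to all-gather the same weight entry from every rank and compare:

import torch
import torch.distributed as dist

def check_weight_sync(ddp_model, rank, world_size):
    # Take one weight entry from this rank and gather it from every rank.
    local = ddp_model.module.net1.weight.detach()[0][:1].clone()
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    if rank == 0:
        print("net1.weight[0][0] per rank:", [t.item() for t in gathered])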

Hey @zzzf,

Does this problem persist if you change the flow from fw(1)->fw(2)->fw(3)->bw(1)->bw(2)->bw(3) to fw([1, 2, 3]) -> bw([1, 2, 3])?

BTW, which version of PyTorch are you using? If it’s <= v1.6, I would expect it to throw an error here.

Hi @mrshenli,

Thanks for your reply. My torch version is 1.6.0 and I receive no warnings/errors.

I’ve tried:

for _ in range(3):
    output = ddp_model(torch.rand(20, 10))
    loss_fn(output, labels).backward()

optimizer.step()

This works as I expected.

When you said fw([1, 2, 3]) -> bw([1, 2, 3]), do you mean the following?

outputs = (ddp_model(torch.rand(20, 10)),
           ddp_model(torch.rand(20, 10)),
           ddp_model(torch.rand(20, 10)))

for i in range(3):
    loss_fn(outputs[i], labels).backward(retain_graph=True)

optimizer.step()

This still doesn’t synchronize the weights and doesn’t throw any error.

@mrshenli This seems to be a gap in DDP where it doesn’t support running backward twice? I couldn’t find any tests that use retain_graph=True with DDP.

The problem seems to be this check: https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L530, which makes the autograd hook skip gradient reduction when its flag is false. The flag is set to false once the first backward finishes (https://github.com/pytorch/pytorch/blob/master/torch/csrc/distributed/c10d/reducer.cpp#L1152) and is never set back to true, since prepare_for_backward is not called again before the subsequent backward passes.
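Given that the very first backward after the three forwards is reduced correctly on all ranks (see the "after 0" lines in the log above), one possible interim workaround is to sum the three losses and run a single backward, so the only backward pass is the one the reducer has prepared for. A rough sketch against the repro above, replacing the backward loop and the final optimizer.step(); not verified beyond that observation:

total_loss = sum(loss_fn(outputs[str(i)], labels) for i in range(3))
# A single backward over the summed loss yields the same accumulated gradients
# as three separate backward passes, but only one backward runs, so the reducer
# hooks fire and the gradients should be all-reduced as usual.
total_loss.backward()
optimizer.step()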

@zzzf Is the workaround you mentioned in your previous reply sufficient for now?