To replicate, change only def demo_basic(rank, world_size) in https://pytorch.org/tutorials/intermediate/ddp_tutorial.html to the following:
def demo_basic(rank, world_size):
    """Reproduce a DDP gradient-desynchronization report.

    Runs three independent forward passes up front, then calls backward()
    on each stored output inside a loop (the first two with
    retain_graph=True), printing the first weight element and its gradient
    on every rank at each step.

    NOTE(review): DDP triggers its gradient all-reduce during backward()
    using state recorded by the most recent forward pass; issuing several
    forwards before any backward, as done here, is outside DDP's
    documented one-forward-one-backward usage and is presumably why the
    gradients diverge across ranks — confirm against the DDP docs.

    Args:
        rank: process rank; also used as the CUDA device id.
        world_size: total number of processes in the group.
    """
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=1)
    optimizer.zero_grad()

    # Three forward passes performed before any backward; outputs are kept
    # so their graphs can be walked later (hence retain_graph below).
    outputs = {}
    outputs['0'] = ddp_model(torch.rand(20, 10))
    outputs['1'] = ddp_model(torch.rand(20, 10))
    outputs['2'] = ddp_model(torch.rand(20, 10))
    labels = torch.rand(20, 5).to(rank)

    for i in range(3):
        print(f"before {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}")
        if i < 2:
            # retain_graph keeps the autograd graph alive for the
            # remaining stored outputs; gradients accumulate into .grad.
            loss_fn(outputs[str(i)], labels).backward(retain_graph=True)
        else:
            loss_fn(outputs[str(i)], labels).backward()
        print(f"after {i}, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")

    # Single optimizer step after all three backwards.
    optimizer.step()
    print(f"last, rank: {rank}, weight: {ddp_model.module.net1.weight[0][0]}, grad: {ddp_model.module.net1.weight.grad[0][0]}")
    cleanup()
and the output is:
before 0, rank: 0, weight: 0.1450435221195221
before 0, rank: 3, weight: 0.1450435221195221
before 0, rank: 1, weight: 0.1450435221195221
before 0, rank: 2, weight: 0.1450435221195221
after 0, rank: 0, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 0, weight: 0.1450435221195221
after 0, rank: 3, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 1, weight: 0.1450435221195221, grad: -0.018003715202212334
after 0, rank: 2, weight: 0.1450435221195221, grad: -0.018003715202212334
before 1, rank: 1, weight: 0.1450435221195221
before 1, rank: 3, weight: 0.1450435221195221
before 1, rank: 2, weight: 0.1450435221195221
after 1, rank: 0, weight: 0.1450435221195221, grad: -0.03955963999032974
after 1, rank: 3, weight: 0.1450435221195221, grad: -0.03072114661335945
before 2, rank: 0, weight: 0.1450435221195221
before 2, rank: 3, weight: 0.1450435221195221
after 1, rank: 1, weight: 0.1450435221195221, grad: -0.03775426745414734
before 2, rank: 1, weight: 0.1450435221195221
after 1, rank: 2, weight: 0.1450435221195221, grad: -0.03235533833503723
before 2, rank: 2, weight: 0.1450435221195221
after 2, rank: 0, weight: 0.1450435221195221, grad: -0.06408560276031494
after 2, rank: 3, weight: 0.1450435221195221, grad: -0.04222358390688896
after 2, rank: 1, weight: 0.1450435221195221, grad: -0.056242190301418304
last, rank: 0, weight: 0.20912912487983704, grad: -0.06408560276031494
last, rank: 3, weight: 0.18726710975170135, grad: -0.04222358390688896
last, rank: 1, weight: 0.201285719871521, grad: -0.056242190301418304
after 2, rank: 2, weight: 0.1450435221195221, grad: -0.04413666948676109
last, rank: 2, weight: 0.1891801953315735, grad: -0.04413666948676109
As the output shows, the gradients (and consequently the post-step weights) are not synchronized across ranks, even though DDP is expected to all-reduce gradients during backward().