The self-contained code below works for me (two GPUs, gloo backend). Each rank runs a different number of iterations; before every forward pass the ranks all-reduce a signal tensor holding 1, and a rank that has finished its own loop contributes a 0 instead, so the ranks that are still looping see a sum below world_size and stop as well.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

def example(rank, world_size):
    # the default env:// rendezvous needs these two variables;
    # keep any values that were already set outside the script
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
    # each rank runs a different number of iterations
    for _ in range((rank + 1) * 10):
        # every rank that is still training contributes a 1; a finished rank
        # contributes a 0 (see below), so a reduced sum below world_size
        # means some other rank has already exited its loop
        signal = torch.tensor([1])
        work = dist.all_reduce(signal, async_op=True)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # check the signal before the backward pass
        work.wait()
        if signal.item() < world_size:
            # another rank is done; stop here so DDP's gradient all-reduce
            # is not left waiting on a rank that no longer calls backward()
            break
        # backward pass
        optimizer.zero_grad()
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()
    if signal.item() >= world_size:
        # this rank exhausted its own iterations; send a 0 so the ranks
        # still looping above know they can stop
        dist.all_reduce(torch.tensor([0]))
    dist.barrier()
    print(f"{rank} done")

def main():
    world_size = 2
    mp.spawn(example,
             args=(world_size,),
             nprocs=world_size,
             join=True)

if __name__ == "__main__":
    main()
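
For completeness: PyTorch 1.7 and later ship a join() context manager on DistributedDataParallel that does this kind of signaling internally, so the manual signal / extra all_reduce / barrier bookkeeping can be dropped. A minimal sketch of the same loop under that assumption, reusing ddp_model, loss_fn and optimizer from above:

with ddp_model.join():
    for _ in range((rank + 1) * 10):
        optimizer.zero_grad()
        # join() keeps ranks that finish early participating in the
        # collective calls until every rank has left the loop
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        loss_fn(outputs, labels).backward()
        optimizer.step()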