Multiprocessing - Barrier Blocks all Processes?

The self-contained code below works for me. The idea: every iteration starts with an asynchronous all_reduce on a one-element signal tensor. As long as every rank is still training, the reduced sum equals world_size; once a rank exhausts its own loop, it contributes a 0, the sum drops below world_size, and the remaining ranks break out before they can block on a collective with no matching call. That guarantees every process reaches the final barrier.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def example(rank, world_size):
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # each rank deliberately runs a different number of iterations
    for _ in range((rank + 1) * 10):
        # non-blocking all_reduce on a one-element signal: while every
        # rank is still training, the reduced sum equals world_size
        signal = torch.tensor([1])
        work = dist.all_reduce(signal, async_op=True)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).to(rank))
        labels = torch.randn(20, 10).to(rank)
        # make sure the signal all_reduce has finished before acting on it
        work.wait()
        # a sum below world_size means a peer already left its loop, so
        # stop here rather than post collectives it can never match
        if signal.item() < world_size:
            break
        # backward pass
        optimizer.zero_grad()
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    # this rank exhausted its own loop: contribute a 0 so peers still in
    # their loops see a sum below world_size and break out as well
    if signal.item() >= world_size:
        dist.all_reduce(torch.tensor([0]))

    # every process is guaranteed to get here, so the barrier cannot hang
    dist.barrier()
    print(f"{rank} done")


def main():
    world_size = 2
    # rendezvous info so init_process_group can find its peers;
    # any free port works here
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    mp.spawn(example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__ == "__main__":
    main()
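
As an aside: if you are on PyTorch 1.8 or newer, DistributedDataParallel ships a built-in join() context manager for exactly this uneven-inputs situation. It shadows the collectives of ranks that finish early, so the hand-rolled signal above becomes unnecessary. A minimal sketch, assuming the same setup as in the code above (example_with_join is just an illustrative name):

def example_with_join(rank, world_size):
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = nn.Linear(10, 10).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # join() shadows the collectives of ranks that exhaust their loop
    # early, so unequal iteration counts cannot deadlock the backward pass
    with ddp_model.join():
        for _ in range((rank + 1) * 10):
            outputs = ddp_model(torch.randn(20, 10).to(rank))
            labels = torch.randn(20, 10).to(rank)
            optimizer.zero_grad()
            loss_fn(outputs, labels).backward()
            optimizer.step()

    # no rank leaves join() until every rank has joined, so this barrier is safe too
    dist.barrier()
    print(f"{rank} done")

You would launch it with the same mp.spawn call as in main() above, just passing example_with_join instead of example.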