What is the difference between wait and barrier in torch distributed?

Does handle.wait() block and synchronize all processes like torch.distributed.barrier()? I am having trouble understanding wait() and barrier(). Based on the description in the torch documentation, I thought wait() would also synchronize. Am I misunderstanding? I need to synchronize the results across all processes before continuing. In particular, what are the differences between the three blocks of code below?

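```python
# 1) asynchronous all_reduce, then wait()
x = func()  # func() is the asker's own function producing a tensor
handle = torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM, async_op=True)
handle.wait()
```

```python
# 2) asynchronous all_reduce, wait(), then barrier()
x = func()
handle = torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM, async_op=True)
handle.wait()
torch.distributed.barrier()
```

```python
# 3) synchronous all_reduce, then barrier()
x = func()
torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM, async_op=False)
torch.distributed.barrier()
```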

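They are similar in a sense, but they serve different purposes. wait() ensures that, once it returns, the asynchronous operation it is associated with has completed (for CUDA collectives, the torch documentation notes that wait() blocks the currently active CUDA stream until the operation finishes rather than blocking the CPU). This means: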

```python
x = func()
torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM, async_op=False)
```

and

```python
x = func()
handle = torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM, async_op=True)
handle.wait()
```

are in fact equivalent. The advantage of the async version is that you can run independent computation between the all_reduce() call and wait(), overlapping communication with computation.
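A minimal sketch of that equivalence, assuming the gloo backend with CPU tensors and a single process (world_size=1) so that it runs standalone. The function name compare_sync_and_async and the file-based init_method are illustrative choices, not part of the original answer; with more ranks you would start one process per rank (for example with torchrun) pointing at the same init_method.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def compare_sync_and_async(rank: int, world_size: int, init_method: str):
    """Hypothetical helper: run a blocking and a non-blocking all_reduce."""
    dist.init_process_group("gloo", init_method=init_method,
                            rank=rank, world_size=world_size)
    try:
        # Blocking call: returns only after the reduction has finished.
        x_sync = torch.full((4,), float(rank + 1))
        dist.all_reduce(x_sync, op=dist.ReduceOp.SUM, async_op=False)

        # Non-blocking call: returns a work handle immediately.
        x_async = torch.full((4,), float(rank + 1))
        handle = dist.all_reduce(x_async, op=dist.ReduceOp.SUM, async_op=True)

        # Work that does not read or write x_async can run here
        # while the reduction is still in flight.
        other = torch.arange(4.0) * 2

        handle.wait()  # from here on, x_async holds the reduced values
        return x_sync, x_async, other
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Single-process run; the FileStore path is a throwaway temp file.
    store = os.path.join(tempfile.mkdtemp(), "pg_store")
    x_sync, x_async, _ = compare_sync_and_async(0, 1, "file://" + store)
    print(torch.equal(x_sync, x_async))
```

With world_size=1 the sum is just the input, so both tensors come back as ones; the point is that both call styles leave the same values in the tensor once wait() has returned.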

On the other hand, barrier() is a standalone collective function. You use it mostly to ensure that all processes in your job reach a certain point in execution (e.g. for coordinating checkpointing). The calls to barrier() in your second and third examples are redundant, though: all_reduce() cannot complete on any rank until every rank has contributed its tensor, so the reduction already synchronizes the processes on x, acting as an implicit barrier.
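A sketch of the checkpointing case, where a barrier is actually needed: only rank 0 writes a file, and no collective involving the other ranks happens in between, so nothing else would stop them from reading it too early. The helper save_then_load, the checkpoint filename, and the single-process gloo setup are hypothetical; with several processes the checkpoint directory would have to be on a filesystem every rank can see.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def save_then_load(rank: int, world_size: int, init_method: str, ckpt_dir: str):
    """Hypothetical helper: rank 0 saves a checkpoint, every rank loads it."""
    dist.init_process_group("gloo", init_method=init_method,
                            rank=rank, world_size=world_size)
    try:
        path = os.path.join(ckpt_dir, "model.pt")  # hypothetical checkpoint path
        if rank == 0:
            # Only one rank writes the checkpoint.
            torch.save({"weights": torch.arange(3.0)}, path)
        # No collective above involves the other ranks, so without this
        # barrier they could try to load before rank 0 has finished saving.
        dist.barrier()
        return torch.load(path)
    finally:
        dist.destroy_process_group()


if __name__ == "__main__":
    store = os.path.join(tempfile.mkdtemp(), "pg_store")
    state = save_then_load(0, 1, "file://" + store, tempfile.mkdtemp())
    print(state["weights"])
```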
