Hi,
I have a short script that sends a tensor from rank 0 to rank 1, and another tensor from rank 1 to rank 0 using calls to dist.broadcast
with async_op=True
. This script hangs at the dist.barrier()
call if the size of the tensors being communicated is 128MB. However, it executes successfully if the size of the tensors is smaller (e.g. if we set numel = 16 * 1024 * 1024
in the script below).
It also executes successfully if the order of the dist.broadcast
calls is altered so that both ranks perform the broadcast to send data from rank 0 → rank 1 first, and then the broadcast from rank 1 → rank 0 next, rather than the current ordering where the calls are interleaved.
Also, calling is_completed()
on the async work handles returns False
after .wait()
has returned which seems to be in conflict with the distributed docs which state that “wait()
- will block the process until the operation is finished. is_completed()
is guaranteed to return True once it returns.”
I’m wondering what could explain this behavior?
I’m using PyTorch 2.1.0 with 2x A100 80GB SXM GPUs.
Note: this is a minimal repro from a larger program, so although this script is equivalent to just doing an AllGather, the larger program requires this interleaved pattern of dist.broadcast
calls.
Script:
import logging
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def run_broadcast_ops(rank: int, world_size: int):
logging.getLogger().setLevel(logging.INFO)
assert world_size == 2, world_size
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29507'
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
logging.info(f"Rank {dist.get_rank()} has joined the process group")
rank = dist.get_rank()
device = f"cuda:{rank}"
dtype = torch.int8
numel = 128 * 1024 * 1024
send1 = torch.zeros(numel, dtype=dtype, device=device)
recv1 = torch.zeros(numel, dtype=dtype, device=device)
if rank == 0:
send1_handle = dist.broadcast(send1, src=0, async_op=True)
recv1_handle = dist.broadcast(recv1, src=1, async_op=True)
send1_handle.wait()
recv1_handle.wait()
logging.info(f"[Rank 0] {send1_handle.wait() = }")
logging.info(f"[Rank 0] {send1_handle.is_completed() = }")
logging.info(f"[Rank 0] {recv1_handle.wait() = }")
logging.info(f"[Rank 0] {recv1_handle.is_completed() = }")
elif rank == 1:
send1_handle = dist.broadcast(send1, src=1, async_op=True)
recv1_handle = dist.broadcast(recv1, src=0, async_op=True)
send1_handle.wait()
recv1_handle.wait()
logging.info(f"[Rank 1] {send1_handle.wait() = }")
logging.info(f"[Rank 1] {send1_handle.is_completed() = }")
logging.info(f"[Rank 1] {recv1_handle.wait() = }")
logging.info(f"[Rank 1] {recv1_handle.is_completed() = }")
else:
raise ValueError(f"Invalid {rank=}")
logging.info(f"Rank {rank} is waiting at barrier")
dist.barrier()
logging.info(f"Rank {rank} finished waiting at barrier")
def main():
world_size = 2
processes = []
mp.set_start_method("spawn")
for rank in range(world_size):
p = mp.Process(target=run_broadcast_ops, args=(rank, world_size))
p.start()
processes.append(p)
for p in processes:
p.join()
if __name__ == "__main__":
main()
Output:
INFO:root:Rank 1 has joined the process group
INFO:root:Rank 0 has joined the process group
INFO:root:[Rank 0] send1_handle.wait() = True
INFO:root:[Rank 0] send1_handle.is_completed() = False
INFO:root:[Rank 0] recv1_handle.wait() = True
INFO:root:[Rank 0] recv1_handle.is_completed() = False
INFO:root:Rank 0 is waiting at barrier
INFO:root:[Rank 1] send1_handle.wait() = True
INFO:root:[Rank 1] send1_handle.is_completed() = False
INFO:root:[Rank 1] recv1_handle.wait() = True
INFO:root:[Rank 1] recv1_handle.is_completed() = False
INFO:root:Rank 1 is waiting at barrier