When the communication graph between the processes forms a cycle, the program gets stuck. It looks like a deadlock: each rank calls the blocking dist.send before dist.recv, so I suspect the ranks end up waiting on each other in a cycle, but I don't know how to confirm or debug that. I wrote a simple demo to reproduce it:
import os

import torch
import torch.distributed as dist
from torch.multiprocessing import Process


def run(rank, size):
    tensor = torch.zeros(1)
    rec_tensor = torch.zeros(1)
    tensor += rank
    # Ring pattern: each rank sends to (rank + 1) % size and
    # receives from (rank - 1) % size, but sends first (blocking).
    if rank == 0:
        dist.send(tensor=tensor, dst=1)
        dist.recv(tensor=rec_tensor, src=2)
    elif rank == 1:
        dist.send(tensor=tensor, dst=2)
        dist.recv(tensor=rec_tensor, src=0)
    else:
        dist.send(tensor=tensor, dst=0)
        dist.recv(tensor=rec_tensor, src=1)
    print("Rank ", rank, " has data ", rec_tensor)


def init_processes(rank, size, fn, backend="gloo"):
    """Initialize the distributed environment."""
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)


if __name__ == "__main__":
    size = 3
    processes = []
    torch.multiprocessing.set_start_method("spawn")
    for rank in range(size):
        p = Process(target=init_processes, args=(rank, size, run))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
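For reference, here is a minimal sketch of a variant I would expect not to hang, assuming the usual non-blocking isend/irecv semantics of torch.distributed (the run_nonblocking name and the generalized ring indices are mine, not from the demo above): each rank posts both its send and its receive before waiting on either, so every send has a matching receive already outstanding and no cycle of blocked sends can form.

def run_nonblocking(rank, size):
    # Hypothetical replacement for run(): the same ring exchange,
    # but with non-blocking point-to-point operations.
    tensor = torch.zeros(1) + rank
    rec_tensor = torch.zeros(1)
    # Post the send and the receive before waiting on either, so the
    # matching recv for every send is already outstanding when the
    # waits begin.
    send_req = dist.isend(tensor=tensor, dst=(rank + 1) % size)
    recv_req = dist.irecv(tensor=rec_tensor, src=(rank - 1) % size)
    send_req.wait()
    recv_req.wait()
    print("Rank ", rank, " has data ", rec_tensor)

Passing run_nonblocking to init_processes in place of run should exercise the same communication pattern without the blocking-send cycle.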