Hey, I am trying to write a custom synchronization scheme with PyTorch in Python. I am training a model and, at the same time, spawning a thread that waits for synchronization requests from other parallel processes. The problem is that my main process gets stuck at the dist.send() call.
An equivalent code snippet is as follows:
import torch
import torch.distributed as dist

def run(rank_local, rank, world_size):
    print("I WAS SPAWNED ", rank_local, " OF ", rank)
    tensor_1 = torch.zeros(1)
    tensor_1 += 1
    while True:
        print("I am spawn of: ", rank, "and my tensor value before receive: ", tensor_1[0])
        # blocking receive; returns the rank of the sending process
        nt = dist.recv(tensor_1)
        print("I am spawn of: ", rank, "and my tensor value after receive from", nt, " is: ", tensor_1[0])

def communication(tensor, rank):
    if rank != 0:
        tensor += (100 + rank)
        dist.send(tensor, dst=0)   # non-zero ranks send to rank 0
    else:
        tensor -= 1000
        dist.send(tensor, dst=1)   # rank 0 sends to rank 1
    print("I AM DONE WITH MY SENDS NOW WHAT?: ", rank)

if __name__ == '__main__':
    # Initialize Process Group
    dist.init_process_group(backend="mpi", group_name="main")
    # get current process information
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    #torch.cuda.set_device(rank % 2)
    # Spawn the receiver thread on this node
    p = ThreadWithReturnValue(target=run, args=(0, rank, world_size))  # mp.Process(target=run, args=(0, rank, world_size))
    p.start()
    tensor = torch.zeros(1)
    communication(tensor, rank)
    p.join()
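(For completeness: ThreadWithReturnValue is just a small threading.Thread subclass that remembers the target's return value so it can be read back from join(). A rough sketch of the helper, in case the exact definition matters:)

    import threading

    class ThreadWithReturnValue(threading.Thread):
        # Thread subclass that captures the target's return value
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self._return = None

        def run(self):
            if self._target is not None:
                self._return = self._target(*self._args, **self._kwargs)

        def join(self, timeout=None):
            super().join(timeout)
            return self._return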
Please note that if I remove the line torch.cuda.set_device(rank % 2) (shown commented out above), the code works perfectly fine. Any thoughts? Why am I not able to achieve the same behavior with CUDA?
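To be explicit, by "with CUDA" I just mean enabling that commented-out line so each rank is pinned to one of the two GPUs on the node:

    torch.cuda.set_device(rank % 2)

The tensors being sent and received stay on the CPU, exactly as in the snippet above.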