Multiple threads on GPU not working?

Hey, I am trying to write a custom synchronization scheme with PyTorch in Python. I am training a model and at the same time spawning a thread that waits for synchronization requests from other parallel processes. The problem is that my main process gets stuck at the dist.send() call.

An equivalent minimal example is as follows:

import torch
import torch.distributed as dist

# ThreadWithReturnValue is a custom helper, presumably a threading.Thread
# subclass that returns the target's result from join()

def run(rank_local, rank, world_size):
    print("I WAS SPAWNED ", rank_local, " OF ", rank)
    
    tensor_1 = torch.zeros(1)
    tensor_1 += 1

    # Receiver loop: block on incoming messages from any rank, forever
    while True:
        print("I am spawn of: ", rank, "and my tensor value before receive: ", tensor_1[0])
        nt = dist.recv(tensor_1)  # returns the sender's rank
        print("I am spawn of: ", rank, "and my tensor value after receive from", nt, " is: ", tensor_1[0])
    
def communication(tensor, rank):
    # Main thread: every non-zero rank sends to rank 0; rank 0 sends to rank 1
    if rank != 0:
        tensor += (100 + rank)
        dist.send(tensor, dst=0)
    else:
        tensor -= 1000
        dist.send(tensor, dst=1)
    print("I AM DONE WITH MY SENDS NOW WHAT?: ", rank)


if __name__ == '__main__':
    
    # Initialize Process Group
    dist.init_process_group(backend="mpi", group_name="main")
    
    # get current process information
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    
    # torch.cuda.set_device(rank % 2)  # uncommenting this line triggers the hang described below

    # Establish local rank and spawn the receiver thread on this node
    p = ThreadWithReturnValue(target=run, args=(0, rank, world_size))  # or: mp.Process(target=run, args=(0, rank, world_size))
    p.start()

    tensor = torch.zeros(1)
    communication(tensor, rank)
    
    p.join()

Please note that with the torch.cuda.set_device(rank % 2) line commented out, as above, the code works perfectly fine; as soon as I enable it, the main process hangs in dist.send(). Any thoughts? Why am I not able to achieve the same behavior with CUDA?

In your snippet, dist.send and dist.recv will race. Since these are blocking calls, neither path continues until its transfer has completed; there must be some interaction between MPI and CUDA that makes the send truly blocking once the device is set, so the ranks end up waiting on each other. You could try explicitly sequencing the calls so that every send has a matching receive posted on the other side, and see if that changes anything. In general, using Python threading together with a single distributed context is discouraged: torch.distributed is not written to be thread safe.
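
For example, with two ranks you could order the calls in opposite directions so every send always has a matching receive, keeping everything on the main thread. This is just a rough sketch; communication_sequenced, the prints, and the tensor arithmetic are illustrative, not part of your code:

import torch
import torch.distributed as dist

def communication_sequenced(tensor, rank):
    if rank == 0:
        # Rank 0 posts its receive first; it matches rank 1's send below
        src = dist.recv(tensor)   # blocks until rank 1 has sent
        print("rank 0 received from", src, ":", tensor[0])
        tensor -= 1000
        dist.send(tensor, dst=1)  # matches rank 1's recv below
    else:
        # Rank 1 mirrors that order: send first, then receive
        tensor += (100 + rank)
        dist.send(tensor, dst=0)  # matched by rank 0's recv above
        src = dist.recv(tensor)   # blocks until rank 0 has sent
        print("rank", rank, "received from", src, ":", tensor[0])

Because the two ranks issue their calls in opposite orders, the pairing works even if MPI's send does not return until a receive is posted. If you need more overlap, dist.isend/dist.irecv return work handles you can wait() on, which makes the pairing explicit without a second thread.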