Dist.all_gather leads to deadlock

Hi everyone,
If all_gather is called more than once, the next cuda() call deadlocks, and GPU utilization stays at 100%.
Details:
I wrapped all_gather in collect:

def collect(x):
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype)
        for _ in range(dist.get_world_size())]
    dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0)

Next:

a_all = collect(a)
b_all = collect(b)
c = torch.rand(10,10).cuda() # deadlock! No error report.

The deadlock doesn’t happen if I don’t collect b_all or don’t use cuda().
My code runs on 4 GPUs using DDP.
PyTorch version: 1.4.
This problem has been confusing me for a couple of days. I'd appreciate any help!

Could you provide a small repro for the problem that you are seeing? I ran the following script locally on my GPU machine and didn’t notice any deadlocks:

import os
import torch
import torch.distributed as dist
from torch.multiprocessing import Process

def collect(x):
    x = x.contiguous()
    out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype)
        for _ in range(dist.get_world_size())]
    dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0)

def run(rank, size):
    """ Gather two tensors from every rank, then allocate a third on this rank's GPU. """
    print('START: {}'.format(rank))
    a = torch.rand(10, 10).cuda(rank)
    b = torch.rand(10, 10).cuda(rank)
    a_all = collect(a)
    b_all = collect(b)
    c = torch.rand(10,10).cuda(rank)
    print('DONE : {}'.format(rank))

def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 4
    processes = []
    for rank in range(size):
        p = Process(target=init_process, args=(rank, size, run))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

Running the script:

$ python /tmp/test_allgather.py
START: 0
START: 1
START: 2
START: 3
DONE : 2
DONE : 3
DONE : 1
DONE : 0
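
Two things differ between my script and your snippet, and either might matter: my script uses the gloo backend, and it passes an explicit index to .cuda(rank), while your snippet calls .cuda() with no index. A bare .cuda() targets the current CUDA device, which is device 0 in every process unless something like torch.cuda.set_device changes it. I can't tell from the post whether that applies to your setup, so treat the following as a sketch under that assumption rather than a diagnosis. pick_device and demo are hypothetical names I made up for illustration, and the file-based init_method is just a convenient way to run it in a single process; it falls back to CPU when no GPU is available.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def pick_device(rank):
    """Hypothetical helper: bind this process to one GPU, or fall back to CPU."""
    if torch.cuda.is_available():
        index = rank % torch.cuda.device_count()
        # After this call, bare .cuda() calls in this process target `index`
        # instead of the default device 0.
        torch.cuda.set_device(index)
        return torch.device("cuda", index)
    return torch.device("cpu")


def collect(x):
    # Same gather as in the original post: one buffer per rank, then concatenate.
    x = x.contiguous()
    out_list = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(out_list, x)
    return torch.cat(out_list, dim=0)


def demo(rank=0, world_size=1):
    """Single-process smoke test; a real job would run one process per rank."""
    # Assumption: a fresh file in a temp directory is used for rendezvous.
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group("gloo", init_method="file://" + init_file,
                            rank=rank, world_size=world_size)
    try:
        device = pick_device(rank)
        a = torch.rand(10, 10, device=device)
        b = torch.rand(10, 10, device=device)
        a_all = collect(a)
        b_all = collect(b)
        # The allocation that hung in the original report, now on an explicit device.
        c = torch.rand(10, 10).to(device)
        return tuple(a_all.shape), tuple(b_all.shape), c.device.type
    finally:
        dist.destroy_process_group()
```

Calling demo() returns the shapes of the two gathered tensors and the device type of c. If your job hangs only when the device is left implicit, calling torch.cuda.set_device with the local rank before any collective would be the first thing I'd try.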