How to time out all_reduce or prevent it from hanging

Hi,

Does anyone know how to properly stop or terminate an all_reduce call when it never gets a reply from the other processes?

Just to explain my question, please see the sample code below. I have 4 processes divided into two sub-groups (group1 and group2) and a shared Queue with 5 elements. Each process keeps getting one element from the queue in a while loop until the queue becomes empty, and inside the loop each process does an all_reduce with its "neighbor" in the same sub-group.
The problem: when one process gets the last element, so the shared queue is now empty and its "neighbor" has already exited the while loop, it hangs and waits forever for the all_reduce to complete. Is there any way to set a timeout for the all_reduce call, or some other way to handle this situation?

Thanks. Please see the code attached below.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, a, q):
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
                master_ip='127.0.0.1', master_port='12346')
    world_size = 4
    torch.distributed.init_process_group(backend="nccl",
                                        init_method=dist_init_method,
                                        world_size=world_size,
                                        rank=rank)
    # Two sub-groups: ranks 0/1 reduce with each other, ranks 2/3 with each other.
    group1 = dist.new_group([0, 1])
    group2 = dist.new_group([2, 3])
    tensor = torch.ones(1)
    device = torch.device('cuda', rank)
    tensor = tensor.to(device)

    # If one rank of a pair gets the last item while its neighbor has already
    # left the loop, the all_reduce below blocks forever.
    while not q.empty():
        current_index = q.get()
        print(f'Process {rank} current index is: {current_index}')

        if rank == 0 or rank == 1:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group1)
        else:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group2)

    print('Rank ', rank, ' has data ', tensor[0])

if __name__ == "__main__":
    a = 1
    ctx = mp.get_context('spawn')
    q = ctx.Queue()  # shared queue with 5 work items
    for index in range(5):
        q.put(index)

    mp.spawn(run, args=(a, q), nprocs=4)

Hey @Yi_Zhang, you can set a timeout in init_process_group. For the NCCL backend, this also requires setting the NCCL_BLOCKING_WAIT env var to 1.

More explanation can be found here: https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group (search for NCCL_BLOCKING_WAIT).
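For reference, a rough sketch of that combination inside each spawned process, where rank is already defined (the 30-second timeout is just an example value):

import os
import datetime
import torch.distributed as dist

os.environ["NCCL_BLOCKING_WAIT"] = "1"  # must be set before init_process_group
dist.init_process_group(backend="nccl",
                        init_method="tcp://127.0.0.1:12346",
                        world_size=4,
                        rank=rank,
                        timeout=datetime.timedelta(seconds=30))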


Hi @mrshenli, thanks for your reply.
I tried with "gloo", and it does terminate the process and throw an exception. But with "nccl", I added the following line to the bash script that runs the .py file:

export NCCL_BLOCKING_WAIT=1

However, it doesn't work.
I also tried adding the following line in the .py file itself, but that doesn't work either.

os.environ["NCCL_BLOCKING_WAIT"] = "1"

Hey @Yi_Zhang, did you set the env var within each spawned process (i.e., in the run function) and before calling init_process_group?

Hi @mrshenli, yes, I did that for each process. I don't understand where the mistake is. Here is the sample code:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import datetime

def run(rank, a, q):
    # Set the env var in each spawned process, before init_process_group.
    os.environ["NCCL_BLOCKING_WAIT"] = "1"
    print('Rank ', rank, 'NCCL_BLOCKING_WAIT is: ', os.environ["NCCL_BLOCKING_WAIT"])
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
                master_ip='127.0.0.1', master_port='12346')
    torch.distributed.init_process_group(backend="nccl",
                                        init_method=dist_init_method,
                                        timeout=datetime.timedelta(seconds=5),
                                        world_size=4,
                                        rank=rank)
    group1 = dist.new_group([0, 1])
    group2 = dist.new_group([2, 3])
    tensor = torch.ones(1)

    device = torch.device('cuda', rank)
    tensor = tensor.to(device)

    while not q.empty():
        print('Rank ', rank, ' in the loop ')
        current_index = q.get()
        print(f'Process {rank} current index is: {current_index}')
        try:
            if rank == 0 or rank == 1:
                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group1)
                print(f'Process {rank} all_reduce tensor is: {tensor}')
            else:
                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group2)
                print(f'Process {rank} all_reduce tensor is: {tensor}')
        except Exception:
            # Note: any timeout exception raised by all_reduce is swallowed here.
            pass
    print('Rank ', rank, ' has data ', tensor[0])

if __name__ == "__main__":
    a = 1
    ctx = mp.get_context('spawn')
    q = ctx.Queue()
    flag = ctx.Queue()  # unused in this snippet
    for index in range(5):
        q.put(index)

    mp.spawn(run, args=(a, q), nprocs=4)

An update: I tried "gloo" without setting a timeout, and it terminates properly. I'm wondering whether "gloo" handles this situation by itself and it has nothing to do with the timeout?

I confirm that I can reproduce this locally.

Hey @osalpekar, do you know if we're missing anything here? Am I correct to assume the following code is expected to abort the op in this case?

@mrshenli - Yes, that is the code block that should abort the op if it times out.

@Yi_Zhang - There is a workaround. The all_reduce call actually returns an async work handle. You can capture that handle and wait on it like this:

work = dist.all_reduce(..., async_op=True)
work.wait(SOME_TIMEOUT)

If the all_reduce call times out, then the wait call will throw an exception.
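For example, a rough sketch of how this could slot into the while loop from your snippet above (it reuses q, tensor, rank, group1, group2, and the datetime import from there; the 10-second timeout and catching RuntimeError are assumptions on my side):

while not q.empty():
    current_index = q.get()
    my_group = group1 if rank in (0, 1) else group2
    # Launch the collective asynchronously, then wait with a timeout
    # instead of blocking indefinitely.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM,
                           group=my_group, async_op=True)
    try:
        work.wait(datetime.timedelta(seconds=10))
    except RuntimeError as e:
        # The peer has probably exited its loop already; stop instead of hanging.
        print(f'Rank {rank} all_reduce timed out: {e}')
        break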

In the meantime, let me try to repro from your most recent code snippet.


@mrshenli @osalpekar, thanks for your replies. I found another way to avoid this situation without using a timeout: I added some checks to make sure each pair of processes terminates after the same number of rounds. But I'm still curious whether you have an answer for the timeout issue. Thanks!
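In case it helps anyone, here is a rough sketch of that idea (not my exact code; it reuses q, rank, device, tensor, group1 and group2 from the snippets above): before the data all_reduce, both ranks of a sub-group first agree, via a MIN all_reduce over a "got work" flag, on whether to do one more round, so they always leave the loop in the same iteration.

import queue  # for queue.Empty

my_group = group1 if rank in (0, 1) else group2
while True:
    try:
        current_index = q.get_nowait()
        have_work = torch.ones(1, device=device)
    except queue.Empty:
        current_index = None
        have_work = torch.zeros(1, device=device)

    # Both ranks of the pair always reach this collective, so it cannot hang.
    dist.all_reduce(have_work, op=dist.ReduceOp.MIN, group=my_group)
    if have_work.item() == 0:
        if current_index is not None:
            q.put(current_index)  # put the unprocessed item back
        break

    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=my_group)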
