Does anyone know how to stop/terminate an all_reduce call properly when it gets no reply from the other processes?
Just to explain my question, please see the sample code below. I have 4 processes divided into two sub-groups (group1 and group2), and a shared Queue holding 5 elements. Each process tries to get one element from the queue in a while loop until the queue becomes empty, and inside the loop each process does an all_reduce with its "neighbor" in the same sub-group.
The problem is that when one process gets the last element, the shared queue is then empty and its "neighbor" has already exited the while loop, so it hangs forever waiting for the all_reduce to complete. Is there any way to set a timeout for the all_reduce call, or some other way to handle this situation?
Thanks. Please see the code attached below.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, a, q):
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12346')
    world_size = 4
    torch.distributed.init_process_group(backend="nccl",
                                         init_method=dist_init_method,
                                         world_size=world_size,
                                         rank=rank)
    group1 = dist.new_group([0, 1])
    group2 = dist.new_group([2, 3])
    tensor = torch.ones(1)
    device = torch.device('cuda', rank)
    tensor = tensor.to(device)
    while not q.empty():
        current_index = q.get()
        print(f'Process {rank} current index is: {current_index}')
        if rank == 0 or rank == 1:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group1)
        else:
            dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group2)
        print('Rank ', rank, ' has data ', tensor[0])


if __name__ == "__main__":
    a = 1
    ctx = mp.get_context('spawn')
    q = ctx.Queue()
    for index in range(5):
        q.put(index)
    mp.spawn(run, args=(a, q), nprocs=4)
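One way to sidestep the hang, independent of backend timeouts, is to make termination itself a collective decision: before each round, every rank votes on whether it still has work, the group reduces the votes with MIN, and every rank sees the same verdict, so all of them leave the loop in the same round. The sketch below is a minimal stand-in, not the real distributed setup: the names ToyGroup, worker, and run_demo are made up, and plain threads plus a barrier play the role of ranks and dist.all_reduce(op=dist.ReduceOp.MIN).

```python
import queue
import threading


class ToyGroup:
    """Toy stand-in for a process group: a MIN all-reduce over threads."""

    def __init__(self, world_size):
        self._vals = []
        self._lock = threading.Lock()
        self._result = None

        def _reduce():
            # Runs exactly once per round, when all ranks have arrived.
            self._result = min(self._vals)
            self._vals.clear()

        self._enter = threading.Barrier(world_size, action=_reduce)
        self._exit = threading.Barrier(world_size)

    def all_reduce_min(self, value):
        with self._lock:
            self._vals.append(value)
        self._enter.wait()   # the reduction happens here
        result = self._result
        self._exit.wait()    # nobody starts a new round before all have read
        return result


def worker(rank, group, work_q, results):
    rounds = 0
    while True:
        try:
            item = work_q.get_nowait()
        except queue.Empty:
            item = None
        # Vote before doing the "real" collective: MIN is 0 as soon as any
        # rank ran dry, so every rank gets the same verdict and exits together.
        # (An item grabbed in the final round is simply dropped in this toy.)
        if group.all_reduce_min(0 if item is None else 1) == 0:
            break
        rounds += 1
    results[rank] = rounds


def run_demo(n_items, world_size):
    work_q = queue.Queue()
    for i in range(n_items):
        work_q.put(i)
    group = ToyGroup(world_size)
    results = [None] * world_size
    threads = [threading.Thread(target=worker, args=(r, group, work_q, results))
               for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


print(run_demo(5, 2))  # every rank reports the same round count, no hang
```

The same vote could be done with a real one-element tensor and dist.all_reduce(op=dist.ReduceOp.MIN) on the sub-group, at the cost of one extra collective per round.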
Hi @mrshenli, thanks for your reply.
I tried with "gloo"; it does terminate the process and throw exceptions. But with "nccl", I added the following line to the bash script that runs the .py file:
export NCCL_BLOCKING_WAIT=1
However, it doesn't work.
I also tried adding the corresponding line in the .py file itself, but that doesn't work either.
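For reference, the usual wiring for the blocking-wait path is that the variable has to be in the environment of every worker before init_process_group runs, and a timeout has to be passed to init_process_group as well (in recent PyTorch releases the variable is spelled TORCH_NCCL_BLOCKING_WAIT). A sketch of a launch script, with train.py as a placeholder name:

```shell
# Export so every spawned worker inherits the setting; assigning it without
# export, or setting it after the workers have started, has no effect.
export NCCL_BLOCKING_WAIT=1
python train.py
```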
Hi @mrshenli, yes, I did that for each process. I don't understand where the mistake is. Here is the sample code:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import datetime


def run(rank, a, q):
    os.environ["NCCL_BLOCKING_WAIT"] = "1"
    print('Rank ', rank, 'NCCL_BLOCKING_WAIT is: ', os.environ["NCCL_BLOCKING_WAIT"])
    dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
        master_ip='127.0.0.1', master_port='12346')
    torch.distributed.init_process_group(backend="nccl",
                                         init_method=dist_init_method,
                                         timeout=datetime.timedelta(seconds=5),
                                         world_size=4,
                                         rank=rank)
    group1 = dist.new_group([0, 1])
    group2 = dist.new_group([2, 3])
    tensor = torch.ones(1)
    device = torch.device('cuda', rank)
    tensor = tensor.to(device)
    while not q.empty():
        print('Rank ', rank, ' in the loop ')
        current_index = q.get()
        print(f'Process {rank} current index is: {current_index}')
        try:
            if rank == 0 or rank == 1:
                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group1)
                print(f'Process {rank} all_reduce tensor is: {tensor}')
            else:
                dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group2)
                print(f'Process {rank} all_reduce tensor is: {tensor}')
        except Exception:
            pass
        print('Rank ', rank, ' has data ', tensor[0])


if __name__ == "__main__":
    a = 1
    ctx = mp.get_context('spawn')
    q = ctx.Queue()
    flag = ctx.Queue()
    for index in range(5):
        q.put(index)
    mp.spawn(run, args=(a, q), nprocs=4)
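One detail in the snippet above works against the timeout even when it fires: `except Exception: pass` swallows the error and re-enters the loop, so the rank just blocks again on the next collective. If the intent is to stop on a timeout, the handler should break out of the loop. A minimal, backend-free sketch of that shape (the consume helper and its collective callable are made up here; the callable stands in for dist.all_reduce, which raises a RuntimeError when a blocking-wait timeout fires):

```python
import queue


def consume(work_q, collective):
    """Drain work_q, doing one collective per item; stop cleanly when the
    collective raises instead of silently retrying forever."""
    done = 0
    while not work_q.empty():
        work_q.get()
        try:
            collective()
        except RuntimeError:
            # A timed-out all_reduce means the neighbour already left the
            # loop; `break` here, where the original code had `pass`.
            break
        done += 1
    return done


# A collective that "times out" on its second call, to exercise the path.
state = {"calls": 0}

def flaky_collective():
    state["calls"] += 1
    if state["calls"] == 2:
        raise RuntimeError("simulated collective timeout")

q = queue.Queue()
for i in range(4):
    q.put(i)
print(consume(q, flaky_collective))  # → 1: stops at the failed second call
```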
An update: I tried "gloo" without setting a timeout, and it terminates properly. I'm wondering whether "gloo" handles this situation by itself, and it doesn't actually have anything to do with the timeout?
@mrshenli @osalpekar, thanks for your replies. I found another way to avoid this situation without using a timeout: I added checks to make sure each pair of processes terminates after the same number of rounds. But I'm still curious whether you have an answer for the timeout issue. Thanks!
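The round-count check described above can be made concrete: if every rank derives, from the initial queue size, how many rounds each member of its pair would naturally run, and both members run the larger of the two counts, neighbours always enter all_reduce the same number of times. A sketch of that idea (the helper name rounds_for_pair and the round-robin assumption about who consumes which item are my own, not from this thread):

```python
def rounds_for_pair(total_items, world_size, pair):
    """How many loop rounds both members of `pair` should run, assuming the
    `total_items` queue elements are consumed round-robin across all ranks.

    Both members use the max of their two natural counts, so neither exits
    while its neighbour is still blocked inside all_reduce; the rank with
    fewer items just does its final all_reduce with an unchanged tensor.
    """
    base, extra = divmod(total_items, world_size)
    counts = [base + (1 if r < extra else 0) for r in range(world_size)]
    return max(counts[r] for r in pair)


# With 5 items and 4 ranks, rank 0 would naturally run 2 rounds and rank 1
# only 1; padding rank 1 to 2 rounds keeps the pair in lock-step.
print(rounds_for_pair(5, 4, (0, 1)))  # → 2
print(rounds_for_pair(5, 4, (2, 3)))  # → 1
```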