Does GLOO on CUDA tensors support reduce op?

From the doc here, Distributed communication package - torch.distributed — PyTorch 1.9.0 documentation, it seems that the gloo backend doesn't support reduce on CUDA tensors, only all_reduce and broadcast.
But I followed the tutorial here, Writing Distributed Applications with PyTorch — PyTorch Tutorials 1.9.0+cu102 documentation, with a simple run function implemented below:

import os
import torch
import torch.distributed as dist

def run4(rank, size):
    """ run4: CUDA reduction. """
    n_gpus = torch.cuda.device_count()
    t = torch.ones(1).cuda(rank % n_gpus)  # place the tensor on a GPU, wrapping around the available devices
    for _ in range(1):
        c = t.clone()
        # dist.all_reduce(c, dist.ReduceOp.SUM)
        dist.reduce(c, dst=0, op=dist.ReduceOp.SUM)  # reduce onto rank 0
        t.set_(c)
    print('[{}] After reduction: rank {} has data {}, backend is {}'.format(
        os.getpid(), rank, t, dist.get_backend()))
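
For completeness, this is launched with a small multiprocessing wrapper along the lines of the one from the tutorial; the launcher below is a minimal sketch of that setup (the address and port values are assumptions, not my exact script):

import os
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, size, fn, backend='gloo'):
    """ Initialize the distributed environment and run fn (as in the tutorial). """
    os.environ['MASTER_ADDR'] = '127.0.0.1'  # assumed single-machine setup
    os.environ['MASTER_PORT'] = '29500'      # assumed free port
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size)

if __name__ == "__main__":
    size = 4  # world_size of 4 on a machine with 2 GPUs
    mp.set_start_method('spawn')
    processes = []
    for rank in range(size):
        p = mp.Process(target=init_process, args=(rank, size, run4))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()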

I tried this with a world_size of 4 on a machine with only 2 GPU cards; here is the result:

[30425] After reduction: rank 2 has data tensor([2.], device='cuda:0'), backend is gloo
[30424] After reduction: rank 1 has data tensor([3.], device='cuda:1'), backend is gloo
[30426] After reduction: rank 3 has data tensor([1.], device='cuda:1'), backend is gloo
[30423] After reduction: rank 0 has data tensor([4.], device='cuda:0'), backend is gloo

Based on the results above, I assume reduce has worked on the GPU because my tensors are placed on CUDA devices, am I right? However, I noticed that all the ranks have participated in the reduce algorithm, meaning that the tensor values on ranks 1, 2, and 3 have also changed during the reduction. It seems that the reduction algorithm is quite naive: with 4 processes it runs 3 rounds (roughly the chained pattern sketched below the list).
1st round: add rank 3 to rank 2.
2nd round: add rank 2 to rank 1.
3rd round: add rank 1 to rank 0.
So when I print the end result, not only does rank 0 have the desired reduced sum, but ranks 1 through (world_size - 2) have also changed their values. Is this the expected result of reduction? I thought ranks 1 through (world_size - 2) wouldn't store the intermediate results. :sweat_smile:
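
Conceptually, what I suspect is happening is something like this chained point-to-point pattern (just an illustration of my guess, not gloo's actual implementation):

def chained_reduce_guess(rank, size, t):
    """ Illustration only: each rank adds the partial sum coming from the rank
        above it and forwards the running total to the rank below it, so the
        intermediate ranks end up holding partial sums. """
    if rank != size - 1:
        partial = torch.zeros_like(t)
        dist.recv(tensor=partial, src=rank + 1)  # partial sum from rank + 1
        t += partial
    if rank != 0:
        dist.send(tensor=t, dst=rank - 1)  # forward the running total down the chain
    return t  # rank 0 holds the full sum; ranks 1..size-2 hold partial sums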

Hi, the documentation might need to be updated.

I ran the reduce operation using gloo and nccl.

gloo reduce runs

[75745] After reduction: rank 0 has data tensor([4.], device='cuda:0'), backend is gloo
[75747] After reduction: rank 2 has data tensor([2.], device='cuda:2'), backend is gloo
[75746] After reduction: rank 1 has data tensor([3.], device='cuda:1'), backend is gloo
[75748] After reduction: rank 3 has data tensor([1.], device='cuda:3'), backend is gloo

[76980] After reduction: rank 2 has data tensor([2.]), backend is gloo
[76981] After reduction: rank 3 has data tensor([1.]), backend is gloo
[76979] After reduction: rank 1 has data tensor([3.]), backend is gloo
[76978] After reduction: rank 0 has data tensor([4.]), backend is gloo

nccl reduce runs

[75986] After reduction: rank 1 has data tensor([1.], device='cuda:1'), backend is nccl
[75985] After reduction: rank 0 has data tensor([4.], device='cuda:0'), backend is nccl
[75987] After reduction: rank 2 has data tensor([1.], device='cuda:2'), backend is nccl
[75988] After reduction: rank 3 has data tensor([1.], device='cuda:3'), backend is nccl

The result of the gloo reduce operation is wrong; it should be as described in the tutorial. This is a known bug, and there is an open issue for it: ProcessGroupGloo reduce produces wrong result · Issue #21480 · pytorch/pytorch · GitHub.
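
Until that is fixed, one possible workaround (just a sketch of mine, not an official recommendation) is to emulate reduce with all_reduce, which gloo does support on CUDA tensors, and keep the result only on the destination rank:

import torch.distributed as dist

def reduce_workaround(t, dst, rank):
    """ Emulate dist.reduce(t, dst) on gloo by using all_reduce and discarding
        the result everywhere except the destination rank. Costs more
        bandwidth than a real reduce. """
    c = t.clone()
    dist.all_reduce(c, op=dist.ReduceOp.SUM)
    if rank == dst:
        t.copy_(c)  # only the destination rank keeps the reduced sum
    return t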

From the doc, Distributed communication package - torch.distributed — PyTorch 1.9.0 documentation, it seems like no backend supports send and recv with CUDA tensors. But I still tried to test it. :sweat_smile:
So, here is my run function:

import torch
import torch.distributed as dist

def run1(rank, size):
    """ run1: Simple P2P synchronously. """
    tensor = torch.zeros(1).cuda(rank)  # one CUDA tensor per rank
    if rank == 0:
        tensor += 1
        # Send the tensor to process 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive the tensor from process 0
        # tensor += 10
        # dist.send(tensor=tensor, dst=1)
        dist.recv(tensor=tensor, src=0)
        # dist.recv(tensor=tensor)
    print("Rank {} has data {}, with addr {}".format(rank, tensor[0], tensor.data_ptr()))

With 2 processes and the NCCL backend, it seems I get correct results:

root@298562e873aa:/opt/sw_home/pytorch-distributed# python distributed.py -f 1 -b nccl
Rank 1 has data 1.0, with addr 139654443565056
Rank 0 has data 1.0, with addr 139731618758656

However, with 2 processes and the gloo backend, I get runtime errors:

root@298562e873aa:/opt/sw_home/pytorch-distributed# python distributed.py -f 1 -b gloo
Process Process-2:
Process Process-1:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/sw_home/pytorch-distributed/distributed.py", line 98, in init_process
    fn(rank, size)
  File "/opt/sw_home/pytorch-distributed/distributed.py", line 41, in run1
    dist.recv(tensor=tensor, src=0)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/distributed_c10d.py", line 850, in recv
    pg.recv([tensor], src, tag).wait()
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:575] Connection closed by peer [172.17.0.13]:31389
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/sw_home/pytorch-distributed/distributed.py", line 98, in init_process
    fn(rank, size)
  File "/opt/sw_home/pytorch-distributed/distributed.py", line 36, in run1
    dist.send(tensor=tensor, dst=1)
  File "/usr/local/lib/python3.6/dist-packages/torch/distributed/distributed_c10d.py", line 805, in send
    default_pg.send([tensor], dst, tag).wait()
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:378] writev [172.17.0.13]:5547: Bad address

So I'm not sure: with CUDA tensors, can we use send and recv with NCCL or gloo? The actual reason I ran these experiments is that I'm on a machine without CUDA UVA support, so it only supports cudaMemcpyPeer(Async), not cudaMemcpy(Async). But as I checked in the gloo source code, there is no cudaMemcpyPeer usage at all, only cudaMemcpyAsync. Thus I'm not sure whether PyTorch DDP with gloo on CUDA tensors will work as expected.

Gloo supports send/recv for CPU tensors only; NCCL supports send/recv for CUDA tensors only.
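
So to move a CUDA tensor between ranks over gloo, you would have to stage it through host memory, roughly like this (a minimal sketch, assuming the extra device-to-host copy is acceptable; the helper names are my own):

import torch
import torch.distributed as dist

def send_cuda_over_gloo(tensor, dst):
    """ Copy the CUDA tensor to the host, then send the CPU copy. """
    dist.send(tensor=tensor.cpu(), dst=dst)

def recv_cuda_over_gloo(tensor, src):
    """ Receive into a host buffer, then copy back onto the CUDA device. """
    buf = torch.empty_like(tensor, device='cpu')
    dist.recv(tensor=buf, src=src)
    tensor.copy_(buf)
    return tensor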
