all_reduce with NCCL times out for large tensors

Consider the following MWE, where I simply try to sum random tensors that are generated on different GPUs. With tensors of size 5,000 it still works, but with size 10,000 it times out.

I’m using CUDA 11.2 with 4 RTX A6000 GPUs. I tried both torch-1.9.1+cu111 and a nightly build compiled directly from the repo. Note that if I use gloo as the backend, it works.

Is this a bug, or is there something wrong with my environment? Any idea what I could try? Thanks.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def fn(rank, world_size):
    # Set up distributed job
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '11235'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Generate random tensor and do all reduce
    x = torch.randn(10_000, device='cuda')
    dist.all_reduce(tensor=x)
    print(x[0])

    dist.destroy_process_group()

if __name__ == '__main__':
    if not torch.cuda.is_available():
        print('No GPUs available, cannot run test')
    else:
        n_gpus = torch.cuda.device_count()
        mp.spawn(
            fn,
            args=(n_gpus,),
            nprocs=n_gpus,
            join=True,
        )

Maybe you could try export NCCL_DEBUG=INFO to get more information about the error?
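If exporting the variable in the shell is inconvenient, it can also be set from the script before the workers are spawned, since child processes inherit the parent's environment. A minimal sketch, assuming a hypothetical helper name enable_nccl_debug (not part of PyTorch or NCCL); NCCL_DEBUG and NCCL_DEBUG_SUBSYS are NCCL's own environment variables:

```python
import os


def enable_nccl_debug(env=None, subsystems=None):
    """Turn on NCCL INFO logging in `env` (defaults to os.environ).

    `subsystems` is an optional comma-separated filter such as "INIT,P2P",
    passed through unchanged as NCCL_DEBUG_SUBSYS.
    """
    env = os.environ if env is None else env
    env["NCCL_DEBUG"] = "INFO"
    if subsystems is not None:
        env["NCCL_DEBUG_SUBSYS"] = subsystems
    return env


# Hypothetical placement in the MWE, before mp.spawn(...):
#     enable_nccl_debug()
```

Calling it before mp.spawn should be roughly equivalent to running export NCCL_DEBUG=INFO before launching the script.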

There you go, this is the output when I do export NCCL_DEBUG=INFO. Note that all these messages only show up once all_reduce is called (I’ve tried adding a time.sleep just before it).

I can’t see anything wrong in the logs though…

<hostname>:3335158:3335158 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
<hostname>:3335158:3335158 [0] NCCL INFO NET/IB : No device found.
<hostname>:3335158:3335158 [0] NCCL INFO NET/Socket : Using [0]eno1np0:<ip-address><0>
<hostname>:3335158:3335158 [0] NCCL INFO Using network Socket
NCCL version 2.7.8+cuda11.1
<hostname>:3335159:3335159 [1] NCCL INFO Bootstrap : Using [0]eno1np0:<ip-address><0>
<hostname>:3335161:3335161 [3] NCCL INFO Bootstrap : Using [0]eno1np0:<ip-address><0>
<hostname>:3335160:3335160 [2] NCCL INFO Bootstrap : Using [0]eno1np0:<ip-address><0>
<hostname>:3335159:3335159 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
<hostname>:3335161:3335161 [3] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
<hostname>:3335160:3335160 [2] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
<hostname>:3335159:3335159 [1] NCCL INFO NET/IB : No device found.
<hostname>:3335161:3335161 [3] NCCL INFO NET/IB : No device found.
<hostname>:3335160:3335160 [2] NCCL INFO NET/IB : No device found.
<hostname>:3335159:3335159 [1] NCCL INFO NET/Socket : Using [0]eno1np0:<ip-address><0>
<hostname>:3335159:3335159 [1] NCCL INFO Using network Socket
<hostname>:3335161:3335161 [3] NCCL INFO NET/Socket : Using [0]eno1np0:<ip-address><0>
<hostname>:3335161:3335161 [3] NCCL INFO Using network Socket
<hostname>:3335160:3335160 [2] NCCL INFO NET/Socket : Using [0]eno1np0:<ip-address><0>
<hostname>:3335160:3335160 [2] NCCL INFO Using network Socket
<hostname>:3335158:3335216 [0] NCCL INFO Channel 00/04 :    0   1   2   3
<hostname>:3335160:3335219 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
<hostname>:3335158:3335216 [0] NCCL INFO Channel 01/04 :    0   3   2   1
<hostname>:3335161:3335218 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
<hostname>:3335158:3335216 [0] NCCL INFO Channel 02/04 :    0   1   2   3
<hostname>:3335160:3335219 [2] NCCL INFO Trees [0] -1/-1/-1->2->1|1->2->-1/-1/-1 [1] 3/-1/-1->2->1|1->2->3/-1/-1 [2] -1/-1/-1->2->1|1->2->-1/-1/-1 [3] 3/-1/-1->2->1|1->2->3/-1/-1
<hostname>:3335159:3335217 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
<hostname>:3335158:3335216 [0] NCCL INFO Channel 03/04 :    0   3   2   1
<hostname>:3335160:3335219 [2] NCCL INFO Setting affinity for GPU 2 to ffffffff,ffffffff,ffffffff,ffffffff
<hostname>:3335161:3335218 [3] NCCL INFO Trees [0] 1/-1/-1->3->0|0->3->1/-1/-1 [1] 0/-1/-1->3->2|2->3->0/-1/-1 [2] 1/-1/-1->3->0|0->3->1/-1/-1 [3] 0/-1/-1->3->2|2->3->0/-1/-1
<hostname>:3335159:3335217 [1] NCCL INFO Trees [0] 2/-1/-1->1->3|3->1->2/-1/-1 [1] 2/-1/-1->1->-1|-1->1->2/-1/-1 [2] 2/-1/-1->1->3|3->1->2/-1/-1 [3] 2/-1/-1->1->-1|-1->1->2/-1/-1
<hostname>:3335161:3335218 [3] NCCL INFO Setting affinity for GPU 3 to ffffffff,ffffffff,ffffffff,ffffffff
<hostname>:3335159:3335217 [1] NCCL INFO Setting affinity for GPU 1 to ffffffff,ffffffff,ffffffff,ffffffff
<hostname>:3335158:3335216 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
<hostname>:3335158:3335216 [0] NCCL INFO Trees [0] 3/-1/-1->0->-1|-1->0->3/-1/-1 [1] -1/-1/-1->0->3|3->0->-1/-1/-1 [2] 3/-1/-1->0->-1|-1->0->3/-1/-1 [3] -1/-1/-1->0->3|3->0->-1/-1/-1
<hostname>:3335158:3335216 [0] NCCL INFO Setting affinity for GPU 0 to ffffffff,ffffffff,ffffffff,ffffffff
<hostname>:3335160:3335219 [2] NCCL INFO Channel 00 : 2[c1000] -> 3[c2000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 00 : 1[82000] -> 2[c1000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO Channel 00 : 0[81000] -> 1[82000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 00 : 3[c2000] -> 0[81000] via P2P/IPC
<hostname>:3335160:3335219 [2] NCCL INFO Channel 00 : 2[c1000] -> 1[82000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO Channel 00 : 0[81000] -> 3[c2000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 00 : 1[82000] -> 3[c2000] via P2P/IPC
<hostname>:3335160:3335219 [2] NCCL INFO Channel 01 : 2[c1000] -> 1[82000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 00 : 3[c2000] -> 1[82000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO Channel 01 : 0[81000] -> 3[c2000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 01 : 1[82000] -> 0[81000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 01 : 3[c2000] -> 2[c1000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 01 : 1[82000] -> 2[c1000] via P2P/IPC
<hostname>:3335160:3335219 [2] NCCL INFO Channel 01 : 2[c1000] -> 3[c2000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 01 : 3[c2000] -> 0[81000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 02 : 1[82000] -> 2[c1000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO Channel 02 : 0[81000] -> 1[82000] via P2P/IPC
<hostname>:3335160:3335219 [2] NCCL INFO Channel 02 : 2[c1000] -> 3[c2000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 02 : 3[c2000] -> 0[81000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO Channel 02 : 0[81000] -> 3[c2000] via P2P/IPC
<hostname>:3335160:3335219 [2] NCCL INFO Channel 02 : 2[c1000] -> 1[82000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 02 : 1[82000] -> 3[c2000] via P2P/IPC
<hostname>:3335160:3335219 [2] NCCL INFO Channel 03 : 2[c1000] -> 1[82000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 02 : 3[c2000] -> 1[82000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO Channel 03 : 0[81000] -> 3[c2000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 03 : 1[82000] -> 0[81000] via P2P/IPC
<hostname>:3335161:3335218 [3] NCCL INFO Channel 03 : 3[c2000] -> 2[c1000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO Channel 03 : 1[82000] -> 2[c1000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO 4 coll channels, 4 p2p channels, 2 p2p channels per peer
<hostname>:3335160:3335219 [2] NCCL INFO Channel 03 : 2[c1000] -> 3[c2000] via P2P/IPC
<hostname>:3335159:3335217 [1] NCCL INFO comm 0x7f0190002e10 rank 1 nranks 4 cudaDev 1 busId 82000 - Init COMPLETE
<hostname>:3335161:3335218 [3] NCCL INFO Channel 03 : 3[c2000] -> 0[81000] via P2P/IPC
<hostname>:3335158:3335216 [0] NCCL INFO 4 coll channels, 4 p2p channels, 2 p2p channels per peer
<hostname>:3335158:3335216 [0] NCCL INFO comm 0x7f7008002e10 rank 0 nranks 4 cudaDev 0 busId 81000 - Init COMPLETE
<hostname>:3335158:3335158 [0] NCCL INFO Launch mode Parallel
<hostname>:3335160:3335219 [2] NCCL INFO 4 coll channels, 4 p2p channels, 2 p2p channels per peer
<hostname>:3335160:3335219 [2] NCCL INFO comm 0x7f4170002e10 rank 2 nranks 4 cudaDev 2 busId c1000 - Init COMPLETE
<hostname>:3335161:3335218 [3] NCCL INFO 4 coll channels, 4 p2p channels, 2 p2p channels per peer
<hostname>:3335161:3335218 [3] NCCL INFO comm 0x7f06cc002e10 rank 3 nranks 4 cudaDev 3 busId c2000 - Init COMPLETE
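
One thing the log does show is that every channel is set up "via P2P/IPC", i.e. direct GPU-to-GPU transfers. A possible experiment (an assumption, not a confirmed diagnosis) is to ask NCCL to avoid P2P with its NCCL_P2P_DISABLE environment variable and check whether the all_reduce then completes. A minimal sketch; the helper name disable_nccl_p2p is hypothetical:

```python
import os


def disable_nccl_p2p(env=None):
    """Ask NCCL not to use direct GPU-to-GPU (P2P) transfers.

    Call this before the NCCL process group is initialized, so that the
    setting is already present in the environment when NCCL starts up.
    """
    env = os.environ if env is None else env
    env["NCCL_P2P_DISABLE"] = "1"
    return env


# Hypothetical placement inside the MWE's fn(rank, world_size):
#     disable_nccl_p2p()
#     dist.init_process_group('nccl', rank=rank, world_size=world_size)
```

If the 10_000-element case works with P2P disabled, that would point at the P2P path on this machine rather than at the script; if it still hangs, the cause is probably elsewhere.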