Stalling on Simple Distributed Barrier

I’m stalling on any kind of barrier between two GPUs in a minimal example with the NVIDIA PyTorch container nvcr.io/nvidia/pytorch:25.03-py3, and nothing in the NCCL logs jumps out at me as a layman. I saw an older post where the problem turned out to be the driver; is that likely the case for me too? I initially used 570.86 and bumped up to 570.124. I’m running on a bare-metal Kubernetes cluster with Ubuntu 24.04, so I’m more or less restricted to driver 570 and have to wait for 575 to be published in container form.

Edit: I originally started with nvcr.io/nvidia/pytorch:25.02-py3 and bumped to 25.03 to try to fix this. My normal training code always used to run fine, and I didn’t notice any relevant changes in the DDP training docs.

torchrun --nproc-per-node=gpu --standalone /mnt/user/test-ddp.py
import os
import torch
import torch.distributed as dist

if __name__ == "__main__":
    # Bind this process to its local GPU before initializing NCCL
    torch.cuda.set_device(f"cuda:{os.environ.get('LOCAL_RANK', 0)}")
    dist.init_process_group(backend="nccl")

    rank = dist.get_rank()
    print(f"Rank of the current process: {rank}")
    
    # This is the collective where the processes stall
    dist.barrier()
    print("All processes have reached the barrier.")

    dist.destroy_process_group()

W0528 12:31:53.174000 1 torch/distributed/run.py:763]
W0528 12:31:53.174000 1 torch/distributed/run.py:763] *****************************************
W0528 12:31:53.174000 1 torch/distributed/run.py:763] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0528 12:31:53.174000 1 torch/distributed/run.py:763] *****************************************
Rank of the current process: 0
Rank of the current process: 1
[rank1]:[W528 12:31:54.339872954 ProcessGroupNCCL.cpp:4782] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank0]:[W528 12:31:54.349033170 ProcessGroupNCCL.cpp:4782] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
ddp-deb-train-0:90:90 [0] NCCL INFO Bootstrap: Using eth0:192.168.19.109<0>
ddp-deb-train-0:90:90 [0] NCCL INFO cudaDriverVersion 12080
ddp-deb-train-0:90:90 [0] NCCL INFO NCCL version 2.25.1+cuda12.8
ddp-deb-train-0:90:90 [0] NCCL INFO Comm config Blocking set to 1
ddp-deb-train-0:91:91 [1] NCCL INFO cudaDriverVersion 12080
ddp-deb-train-0:91:91 [1] NCCL INFO Bootstrap: Using eth0:192.168.19.109<0>
ddp-deb-train-0:91:91 [1] NCCL INFO NCCL version 2.25.1+cuda12.8
ddp-deb-train-0:91:91 [1] NCCL INFO Comm config Blocking set to 1
ddp-deb-train-0:90:105 [0] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9)
ddp-deb-train-0:90:105 [0] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9)
ddp-deb-train-0:90:105 [0] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
ddp-deb-train-0:90:105 [0] NCCL INFO P2P plugin v9 IBext_v9
ddp-deb-train-0:90:105 [0] NCCL INFO NET/IB : No device found.
ddp-deb-train-0:90:105 [0] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.19.109<0>
ddp-deb-train-0:90:105 [0] NCCL INFO NET/IB : No device found.
ddp-deb-train-0:90:105 [0] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.19.109<0>
ddp-deb-train-0:90:105 [0] NCCL INFO NET/Socket : Using [0]eth0:192.168.19.109<0>
ddp-deb-train-0:90:105 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so.
ddp-deb-train-0:90:105 [0] NCCL INFO Using network Socket
ddp-deb-train-0:90:105 [0] NCCL INFO ncclCommInitRankConfig comm 0x25bb5550 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 81000 commId 0x1d0b7e730eb15adf - Init START
ddp-deb-train-0:91:106 [1] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9)
ddp-deb-train-0:91:106 [1] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9)
ddp-deb-train-0:91:106 [1] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
ddp-deb-train-0:91:106 [1] NCCL INFO P2P plugin v9 IBext_v9
ddp-deb-train-0:91:106 [1] NCCL INFO NET/IB : No device found.
ddp-deb-train-0:91:106 [1] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.19.109<0>
ddp-deb-train-0:91:106 [1] NCCL INFO NET/IB : No device found.
ddp-deb-train-0:91:106 [1] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.19.109<0>
ddp-deb-train-0:91:106 [1] NCCL INFO NET/Socket : Using [0]eth0:192.168.19.109<0>
ddp-deb-train-0:91:106 [1] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so.
ddp-deb-train-0:91:106 [1] NCCL INFO Using network Socket
ddp-deb-train-0:91:106 [1] NCCL INFO ncclCommInitRankConfig comm 0xaa2ec70 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId e1000 commId 0x1d0b7e730eb15adf - Init START
ddp-deb-train-0:91:106 [1] NCCL INFO RAS client listening socket at ::1<28028>
ddp-deb-train-0:90:105 [0] NCCL INFO RAS client listening socket at ::1<28028>
ddp-deb-train-0:91:106 [1] NCCL INFO Bootstrap timings total 0.001023 (create 0.000037, send 0.000134, recv 0.000426, ring 0.000034, delay 0.000000)
ddp-deb-train-0:90:105 [0] NCCL INFO Bootstrap timings total 0.009177 (create 0.000035, send 0.000137, recv 0.008375, ring 0.000030, delay 0.000000)
ddp-deb-train-0:91:106 [1] NCCL INFO Setting affinity for GPU 1 to ffffffff,ffffffff,00000000,00000000,ffffffff,ffffffff,00000000,00000000
ddp-deb-train-0:90:105 [0] NCCL INFO Setting affinity for GPU 0 to ffffffff,ffffffff,00000000,00000000,ffffffff,ffffffff,00000000,00000000
ddp-deb-train-0:91:106 [1] NCCL INFO comm 0xaa2ec70 rank 1 nRanks 2 nNodes 1 localRanks 2 localRank 1 MNNVL 0
ddp-deb-train-0:90:105 [0] NCCL INFO comm 0x25bb5550 rank 0 nRanks 2 nNodes 1 localRanks 2 localRank 0 MNNVL 0
ddp-deb-train-0:91:106 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] 0/-1/-1->1->-1 [2] -1/-1/-1->1->0 [3] 0/-1/-1->1->-1
ddp-deb-train-0:90:105 [0] NCCL INFO Channel 00/04 : 0 1
ddp-deb-train-0:91:106 [1] NCCL INFO P2P Chunksize set to 131072
ddp-deb-train-0:90:105 [0] NCCL INFO Channel 01/04 : 0 1
ddp-deb-train-0:90:105 [0] NCCL INFO Channel 02/04 : 0 1
ddp-deb-train-0:90:105 [0] NCCL INFO Channel 03/04 : 0 1
ddp-deb-train-0:90:105 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1 [2] 1/-1/-1->0->-1 [3] -1/-1/-1->0->1
ddp-deb-train-0:90:105 [0] NCCL INFO P2P Chunksize set to 131072
ddp-deb-train-0:90:105 [0] NCCL INFO Check P2P Type intraNodeP2pSupport 1 directMode 0
ddp-deb-train-0:90:110 [0] NCCL INFO [Proxy Service] Device 0 CPU core 237
ddp-deb-train-0:91:109 [1] NCCL INFO [Proxy Service] Device 1 CPU core 109
ddp-deb-train-0:91:111 [1] NCCL INFO [Proxy Service UDS] Device 1 CPU core 219
ddp-deb-train-0:90:112 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 75
ddp-deb-train-0:91:106 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
ddp-deb-train-0:91:106 [1] NCCL INFO 4 coll channels, 4 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
ddp-deb-train-0:90:105 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
ddp-deb-train-0:90:105 [0] NCCL INFO 4 coll channels, 4 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
ddp-deb-train-0:90:105 [0] NCCL INFO CC Off, workFifoBytes 1048576
ddp-deb-train-0:90:105 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol.
ddp-deb-train-0:91:106 [1] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol.
ddp-deb-train-0:90:105 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol.
ddp-deb-train-0:91:106 [1] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol.
ddp-deb-train-0:90:105 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead.
ddp-deb-train-0:91:106 [1] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead.
ddp-deb-train-0:90:105 [0] NCCL INFO ncclCommInitRankConfig comm 0x25bb5550 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 81000 commId 0x1d0b7e730eb15adf - Init COMPLETE
ddp-deb-train-0:91:106 [1] NCCL INFO ncclCommInitRankConfig comm 0xaa2ec70 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId e1000 commId 0x1d0b7e730eb15adf - Init COMPLETE
ddp-deb-train-0:90:105 [0] NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 2 total 0.14 (kernels 0.11, alloc 0.00, bootstrap 0.01, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
ddp-deb-train-0:91:106 [1] NCCL INFO Init timings - ncclCommInitRankConfig: rank 1 nranks 2 total 0.13 (kernels 0.12, alloc 0.00, bootstrap 0.00, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
ddp-deb-train-0:90:113 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/CUMEM
ddp-deb-train-0:90:113 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/CUMEM
ddp-deb-train-0:90:113 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/CUMEM
ddp-deb-train-0:90:113 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/CUMEM
ddp-deb-train-0:91:114 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/CUMEM
ddp-deb-train-0:91:114 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/CUMEM
ddp-deb-train-0:91:114 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/CUMEM
ddp-deb-train-0:91:114 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/CUMEM
ddp-deb-train-0:91:114 [1] NCCL INFO Connected all rings, use ring PXN 0 GDR 1
ddp-deb-train-0:90:113 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 1
root@ddp-deb-train-0:/workspace# nvidia-smi topo -m
        GPU0    GPU1    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    SYS     SYS     64-127,192-255  1               N/A
GPU1    NODE     X      SYS     SYS     64-127,192-255  1               N/A
NIC0    SYS     SYS      X      PIX
NIC1    SYS     SYS     PIX      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1

root@ddp-deb-train-0:/workspace# nvidia-smi
Wed May 28 12:37:59 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 570.124.06     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 5000 Ada Gene...    On  |   00000000:81:00.0 Off |                  Off |
| 30%   40C    P2             73W /  250W |     544MiB /  32760MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX 5000 Ada Gene...    On  |   00000000:E1:00.0 Off |                  Off |
| 30%   39C    P2             75W /  250W |     544MiB /  32760MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A              90      C   /usr/bin/python                         534MiB |
|    1   N/A  N/A              91      C   /usr/bin/python                         534MiB |
+-----------------------------------------------------------------------------------------+

I cannot reproduce any issue using these containers and see:

torchrun --nproc_per_node=8 --standalone tmp.py 
W0528 15:28:07.034000 165 torch/distributed/run.py:763] 
W0528 15:28:07.034000 165 torch/distributed/run.py:763] *****************************************
W0528 15:28:07.034000 165 torch/distributed/run.py:763] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0528 15:28:07.034000 165 torch/distributed/run.py:763] *****************************************
Rank of the current process: 0
Rank of the current process: 1
[rank1]:[W528 15:28:16.564921279 ProcessGroupNCCL.cpp:4782] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
Rank of the current process: 7
...
All processes have reached the barrier.All processes have reached the barrier.All processes have reached the barrier.All processes have reached the barrier.All processes have reached the barrier.All processes have reached the barrier.
All processes have reached the barrier.
All processes have reached the barrier.
root@luna-prod-1309-au:/workspace/src# 

I just tried it again this morning, and the problem only seems to occur when not all of the node’s GPUs are assigned to the container. If I use all GPUs on a node it runs fine, but if only two are assigned (as in the example above) it stalls.

I can replicate this problem with CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc-per-node=gpu --standalone /mnt/user/test-ddp.py. Without CUDA_VISIBLE_DEVICES set, the same command runs fine when all GPUs on the node are visible to the container.

# Using all GPUs on the node: no problem
resources:
  limits:
    cpu: "12"
    memory: 32Gi
    nvidia.com/gpu: "5"

# Using only two GPUs on the node: stalling
resources:
  limits:
    cpu: "12"
    memory: 32Gi
    nvidia.com/gpu: "2"

Even resolving the ProcessGroupNCCL.cpp:4782 warning by explicitly setting the device_id as hinted does not solve this problem.

# Initialize the process group
dist.init_process_group(
    backend="nccl",
    device_id=torch.device(f"cuda:{os.environ.get('LOCAL_RANK', 0)}"),
)
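
For completeness, the other option the warning mentions is passing device_ids to the barrier call itself; as I understand it, that would look roughly like this (sketch only):

# Sketch of the other hint from the warning: pin barrier() to this rank's GPU
local_rank = int(os.environ.get("LOCAL_RANK", 0))
dist.barrier(device_ids=[local_rank])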

No reproduction on my side:

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc-per-node=gpu --standalone tmp.py 
W0528 23:38:01.746000 646 torch/distributed/run.py:763] 
W0528 23:38:01.746000 646 torch/distributed/run.py:763] *****************************************
W0528 23:38:01.746000 646 torch/distributed/run.py:763] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0528 23:38:01.746000 646 torch/distributed/run.py:763] *****************************************
Rank of the current process: 0
Rank of the current process: 1
[rank1]:[W528 23:38:04.793574794 ProcessGroupNCCL.cpp:4782] [PG ID 0 PG GUID 0 Rank 1]  using GPU 1 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
[rank0]:[W528 23:38:04.928544033 ProcessGroupNCCL.cpp:4782] [PG ID 0 PG GUID 0 Rank 0]  using GPU 0 to perform barrier as devices used by this process are currently unknown. This can potentially cause a hang if this rank to GPU mapping is incorrect. Specify device_ids in barrier() to force use of a particular device, or call init_process_group() with a device_id.
All processes have reached the barrier.
All processes have reached the barrier.

Any other NCCL debug flags off the top of your head I could use to find out why these workers are just blocking? I also have to escalate to SIGKILL to stop them; SIGTERM doesn’t do it.

NCCL_DEBUG=INFO usually helps, but I did not see anything obvious in your logs. As another debugging step you could try disabling the cuMem usage via NCCL_CUMEM_ENABLE=0, or disabling P2P access in general. Alternatively, you can run the pure nccl-tests to check whether these collectives work in your setup.
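
For example, your original launch command can be reused with those flags set (just a sketch, adjust to whatever you want to rule out):

NCCL_DEBUG=INFO NCCL_CUMEM_ENABLE=0 torchrun --nproc-per-node=gpu --standalone /mnt/user/test-ddp.py

The nccl-tests binaries (e.g. all_reduce_perf) give you a pure NCCL reproduction without PyTorch in the loop.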

I think my NCCL setup is scuffed somehow; I tried running the nccl-tests example with 2 GPUs.

root@ddp-deb-3090-train-0:/workspace/nccl-tests# ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 2
# nThread 1 nGpus 2 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
#  Rank  0 Group  0 Pid   6747 on ddp-deb-3090-train-0 device  0 [0000:21:00] NVIDIA RTX 5000 Ada Generation
#  Rank  1 Group  0 Pid   6747 on ddp-deb-3090-train-0 device  1 [0000:41:00] NVIDIA RTX 5000 Ada Generation
ddp-deb-3090-train-0:6747:6747 [0] NCCL INFO Bootstrap: Using eth0:192.168.19.75<0>
ddp-deb-3090-train-0:6747:6747 [0] NCCL INFO cudaDriverVersion 12080
ddp-deb-3090-train-0:6747:6747 [0] NCCL INFO NCCL version 2.25.1+cuda12.8
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v9 (v9)
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v9)
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO P2P plugin v9 IBext_v9
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/IB : No device found.
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.19.75<0>
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/IB : No device found.
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/IB : Using [RO]; OOB eth0:192.168.19.75<0>
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO NET/Socket : Using [0]eth0:192.168.19.75<0>
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so.
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Using network Socket
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so.
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO Using network Socket
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO ncclCommInitAll comm 0x629c72657970 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 21000 commId 0x5306469da442628d - Init START
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO ncclCommInitAll comm 0x629c726d72a0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 41000 commId 0x5306469da442628d - Init START
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO RAS client listening socket at ::1<28028>
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Bootstrap timings total 0.001210 (create 0.000038, send 0.000122, recv 0.000413, ring 0.000043, delay 0.000000)
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO Bootstrap timings total 0.001150 (create 0.000091, send 0.000315, recv 0.000223, ring 0.000034, delay 0.000000)
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO Setting affinity for GPU 1 to ffffffff,ffffffff,00000000,00000000,ffffffff,ffffffff
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Setting affinity for GPU 0 to ffffffff,ffffffff,00000000,00000000,ffffffff,ffffffff
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO comm 0x629c726d72a0 rank 1 nRanks 2 nNodes 1 localRanks 2 localRank 1 MNNVL 0
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO comm 0x629c72657970 rank 0 nRanks 2 nNodes 1 localRanks 2 localRank 0 MNNVL 0
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Channel 00/04 : 0 1
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Channel 01/04 : 0 1
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Channel 02/04 : 0 1
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Channel 03/04 : 0 1
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1 [2] 1/-1/-1->0->-1 [3] -1/-1/-1->0->1
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO P2P Chunksize set to 131072
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] 0/-1/-1->1->-1 [2] -1/-1/-1->1->0 [3] 0/-1/-1->1->-1
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO P2P Chunksize set to 131072
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Check P2P Type intraNodeP2pSupport 1 directMode 1
ddp-deb-3090-train-0:6747:6760 [1] NCCL INFO [Proxy Service] Device 1 CPU core 151
ddp-deb-3090-train-0:6747:6761 [1] NCCL INFO [Proxy Service UDS] Device 1 CPU core 167
ddp-deb-3090-train-0:6747:6762 [0] NCCL INFO [Proxy Service] Device 0 CPU core 191
ddp-deb-3090-train-0:6747:6763 [0] NCCL INFO [Proxy Service UDS] Device 0 CPU core 176
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO 4 coll channels, 4 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO 4 coll channels, 4 collnet channels, 0 nvls channels, 4 p2p channels, 2 p2p channels per peer
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO CC Off, workFifoBytes 1048576
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol.
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol.
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead.
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO ncclCommInitAll comm 0x629c72657970 rank 0 nranks 2 cudaDev 0 nvmlDev 0 busId 21000 commId 0x5306469da442628d - Init COMPLETE
ddp-deb-3090-train-0:6747:6756 [0] NCCL INFO Init timings - ncclCommInitAll: rank 0 nranks 2 total 0.27 (kernels 0.18, alloc 0.08, bootstrap 0.00, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO ncclCommInitAll comm 0x629c726d72a0 rank 1 nranks 2 cudaDev 1 nvmlDev 1 busId 41000 commId 0x5306469da442628d - Init COMPLETE
ddp-deb-3090-train-0:6747:6757 [1] NCCL INFO Init timings - ncclCommInitAll: rank 1 nranks 2 total 0.27 (kernels 0.18, alloc 0.07, bootstrap 0.00, allgathers 0.00, topo 0.01, graphs 0.00, connections 0.00, rest 0.00)
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
ddp-deb-3090-train-0:6747:6765 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[1] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6764 [1] NCCL INFO Channel 00/0 : 1[1] -> 0[0] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6765 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[1] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6764 [1] NCCL INFO Channel 01/0 : 1[1] -> 0[0] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6765 [0] NCCL INFO Channel 02/0 : 0[0] -> 1[1] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6764 [1] NCCL INFO Channel 02/0 : 1[1] -> 0[0] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6765 [0] NCCL INFO Channel 03/0 : 0[0] -> 1[1] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6764 [1] NCCL INFO Channel 03/0 : 1[1] -> 0[0] via P2P/direct pointer
ddp-deb-3090-train-0:6747:6765 [0] NCCL INFO Connected all rings, use ring PXN 0 GDR 1
ddp-deb-3090-train-0:6747:6764 [1] NCCL INFO Connected all rings, use ring PXN 0 GDR 1
root@ddp-deb-3090-train-0:/workspace# nvidia-smi
Wed May 28 23:57:33 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 570.124.06             Driver Version: 570.124.06     CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA RTX 5000 Ada Gene...    On  |   00000000:21:00.0 Off |                  Off |
| 30%   38C    P2             68W /  250W |     906MiB /  32760MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX 5000 Ada Gene...    On  |   00000000:41:00.0 Off |                  Off |
| 30%   41C    P2             71W /  250W |     906MiB /  32760MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA RTX 5000 Ada Gene...    On  |   00000000:61:00.0 Off |                  Off |
| 30%   28C    P8             17W /  250W |       5MiB /  32760MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            6747      C   ./build/all_reduce_perf                 896MiB |
|    1   N/A  N/A            6747      C   ./build/all_reduce_perf                 896MiB |
+-----------------------------------------------------------------------------------------+

I can get it to consistently run with NCCL_P2P_DISABLE=1.
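
That is, something like this runs to completion every time:

NCCL_P2P_DISABLE=1 torchrun --nproc-per-node=gpu --standalone /mnt/user/test-ddp.py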

It doesn’t matter whether I select two GPUs from the same or different parts of the PCI topology (a SYS pair or a NODE pair): I can’t reliably get multi-GPU to run without NCCL_P2P_DISABLE=1. The same holds whether the selected GPUs are enumerated 0,1 on the real host or 2,4.

Full topology of the host system for clarity.

        GPU0    GPU1    GPU2    GPU3    GPU4    NIC0    NIC1    CPU Affinity    NUMA Affinity   GPU NUMA ID
GPU0     X      NODE    NODE    SYS     SYS     PHB     PHB     0-63,128-191    0               N/A
GPU1    NODE     X      NODE    SYS     SYS     NODE    NODE    0-63,128-191    0               N/A
GPU2    NODE    NODE     X      SYS     SYS     NODE    NODE    0-63,128-191    0               N/A
GPU3    SYS     SYS     SYS      X      NODE    SYS     SYS     64-127,192-255  1               N/A
GPU4    SYS     SYS     SYS     NODE     X      SYS     SYS     64-127,192-255  1               N/A
NIC0    PHB     NODE    NODE    SYS     SYS      X      PIX
NIC1    PHB     NODE    NODE    SYS     SYS     PIX      X 

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_0
  NIC1: mlx5_1

Assuming I understand your previous post correctly and you can reproduce the hang using NCCL itself, I would recommend creating an issue in their GitHub repository.

Actually, I’ve just noticed a whole lot of crud in dmesg that is triggered when running these tests. I suppose it points to something on my host system, at the driver level.

[Thu May 29 00:16:02 2025] AMD-Vi: IOMMU Event log restarting
[Thu May 29 00:16:02 2025] amd_iommu_report_page_fault: 507 callbacks suppressed
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] nvidia 0000:21:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0028 address=0xd1ffc068 flags=0x0020]
[Thu May 29 00:16:02 2025] NVRM: GPU at PCI:0000:21:00: GPU-2cee2dbd-8a7f-b7da-4606-526b6699e7d5
[Thu May 29 00:16:02 2025] NVRM: GPU Board Serial Number: 1323623060244
[Thu May 29 00:16:02 2025] NVRM: Xid (PCI:0000:21:00): 31, pid=1847614, name=all_reduce_perf, Ch 00000009, intr 00000000. MMU Fault: ENGINE GRAPHICS GPC0 GPCCLIENT_T1_2 faulted @ 0x0_00000000. Fault is of type FAULT_PDE ACCESS_TYPE_VIRT_READ
[Thu May 29 00:19:00 2025] nvidia 0000:61:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x000a address=0xe5b6e068 flags=0x0020]
[Thu May 29 00:19:00 2025] nvidia 0000:61:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x000a address=0xe5b6e070 flags=0x0020]
[Thu May 29 00:26:02 2025] amd_iommu_report_page_fault: 508 callbacks suppressed
[Thu May 29 00:26:02 2025] nvidia 0000:41:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0014 address=0xe895e068 flags=0x0020]
[Thu May 29 00:26:02 2025] nvidia 0000:41:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0014 address=0xe895e070 flags=0x0020]
[Thu May 29 00:28:15 2025] nvidia 0000:e1:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0034 address=0xe0fee068 flags=0x0020]
[Thu May 29 00:28:15 2025] nvidia 0000:e1:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x0034 address=0xe0fee070 flags=0x0020]
[Thu May 29 00:28:33 2025] nvidia 0000:61:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x000a address=0xe5b4e068 flags=0x0020]
[Thu May 29 00:28:33 2025] nvidia 0000:61:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x000a address=0xe5b4e070 flags=0x0020]
[Thu May 29 00:39:30 2025] nvidia 0000:61:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x000a address=0xe5b2e068 flags=0x0020]
[Thu May 29 00:39:30 2025] nvidia 0000:61:00.0: AMD-Vi: Event logged [IO_PAGE_FAULT domain=0x000a address=0xe5b2e070 flags=0x0020]

Disable IOMMU as described here.
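
Typically that means turning off the IOMMU/AMD-Vi option in the BIOS, or alternatively passing a kernel parameter such as amd_iommu=off (or iommu=pt for passthrough mode); the exact option depends on your platform.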

It took me a while to be able to test this: disabling IOMMU required changes in the BIOS and a reboot, and security.canonical.com was faulty/down all day, which prevented gpu-operator from reinstalling the driver. But I can confirm that when NCCL tries to use P2P/CUMEM, it now works fine with no stalling.

It’s interesting to note that only pairs of GPUs seem to invoke P2P/CUMEM, regardless of whether they sit on the same NUMA node or not, while using more than two GPUs always goes through SHM/direct/direct between all GPUs.

Triple on the same NUMA node

ddp-deb-3090-1-train-0:2034:2067 [2] NCCL INFO Channel 00 : 2[2] -> 0[0] via SHM/direct/direct
ddp-deb-3090-1-train-0:2034:2067 [2] NCCL INFO Channel 01 : 2[2] -> 0[0] via SHM/direct/direct
ddp-deb-3090-1-train-0:2033:2066 [1] NCCL INFO Channel 00 : 1[1] -> 2[2] via SHM/direct/direct
ddp-deb-3090-1-train-0:2034:2067 [2] NCCL INFO Channel 02 : 2[2] -> 0[0] via SHM/direct/direct
ddp-deb-3090-1-train-0:2033:2066 [1] NCCL INFO Channel 01 : 1[1] -> 2[2] via SHM/direct/direct
ddp-deb-3090-1-train-0:2034:2067 [2] NCCL INFO Channel 03 : 2[2] -> 0[0] via SHM/direct/direct
ddp-deb-3090-1-train-0:2033:2066 [1] NCCL INFO Channel 02 : 1[1] -> 2[2] via SHM/direct/direct
ddp-deb-3090-1-train-0:2033:2066 [1] NCCL INFO Channel 03 : 1[1] -> 2[2] via SHM/direct/direct
ddp-deb-3090-1-train-0:2032:2068 [0] NCCL INFO Channel 00 : 0[0] -> 1[1] via SHM/direct/direct
ddp-deb-3090-1-train-0:2032:2068 [0] NCCL INFO Channel 01 : 0[0] -> 1[1] via SHM/direct/direct
ddp-deb-3090-1-train-0:2032:2068 [0] NCCL INFO Channel 02 : 0[0] -> 1[1] via SHM/direct/direct
ddp-deb-3090-1-train-0:2032:2068 [0] NCCL INFO Channel 03 : 0[0] -> 1[1] via SHM/direct/direct

Pair on different NUMA nodes

ddp-deb-3090-1-train-0:1925:1947 [1] NCCL INFO Channel 00/0 : 1[4] -> 0[0] via P2P/CUMEM
ddp-deb-3090-1-train-0:1925:1947 [1] NCCL INFO Channel 01/0 : 1[4] -> 0[0] via P2P/CUMEM
ddp-deb-3090-1-train-0:1924:1948 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[4] via P2P/CUMEM
ddp-deb-3090-1-train-0:1924:1948 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[4] via P2P/CUMEM

Great! Thanks for confirming this fixes the issue.