DDP hangs upon creation

Hi.
I’m trying to use DDP on two nodes, but the DDP creation hangs forever. The code is like this:

import torch
import torch.nn as nn
import torch.distributed as dist
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import datetime

os.environ['MASTER_ADDR'] = '$myip'
os.environ['MASTER_PORT'] = '7777'
# os.environ['NCCL_BLOCKING_WAIT'] = '1'
os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '1'


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

The following lines are different for each node:

dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10), world_size=2, rank=0)  # rank=0 for the $myip node, rank=1 for the other node

model = ToyModel().to(0)
ddp_model = DDP(model, device_ids=[0], output_device=0)  # This is where it hangs.

One of the nodes shows this:

In [4]: model = ToyModel().to(0)
   ...: ddp_model = DDP(model, device_ids=[0], output_device=0)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-7fbd4245ff44> in <module>
      1 model = ToyModel().to(0)
----> 2 ddp_model = DDP(model, device_ids=[0], output_device=0)

~/bin/anaconda3/envs/torch/lib/python3.8/site-packages/torch/nn/parallel/distributed.py in __init__(self, module, device_ids, output_device, dim, broadcast_buffers, process_group, bucket_cap_mb, find_unused_parameters, check_reduction, gradient_as_bucket_view)
    576         parameters, expect_sparse_gradient = self._build_params_for_reducer()
    577         # Verify model equivalence.
--> 578         dist._verify_model_across_ranks(self.process_group, parameters)
    579         # Sync params and buffers. Ensures all DDP models start off at the same value.
    580         self._sync_params_and_buffers(authoritative_rank=0)

RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1634272172048/work/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:957, unhandled system error, NCCL version 21.0.3
ncclSystemError: System call (socket, malloc, munmap, etc) failed.

Any advice? Thanks~


Hey @Musoy_King, it looks like the NCCL broadcast crashed. Can you check whether directly calling dist.broadcast fails too?

Also, it looks like you are using IPython or a notebook. Can you try running the script directly with python on the two nodes?
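
For the first check, a minimal sketch along these lines (MASTER_ADDR, MASTER_PORT, and the rank are placeholders taken from your snippet; adjust per node) should be enough to tell whether a plain dist.broadcast also fails:

import os
import datetime
import torch
import torch.distributed as dist

os.environ['MASTER_ADDR'] = '$myip'   # placeholder: IP of the rank-0 node
os.environ['MASTER_PORT'] = '7777'

# rank=0 on the $myip node, rank=1 on the other node
dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=60), world_size=2, rank=0)

t = torch.ones(1, device='cuda:0')
dist.broadcast(t, src=0)    # if the NCCL transport is broken, this should fail or hang the same way
torch.cuda.synchronize()
print('broadcast finished:', t)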

Hi, I’ve encountered exactly the same problem.
Can you share how you dealt with this error?

Can you run with NCCL_DEBUG=INFO and share the logs? That would provide more detailed information about what went wrong.
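
One way to do that (a minimal sketch; you can also export the variables in the shell before launching) is to set them at the top of the script, before the process group and the DDP model are created:

import os

os.environ['NCCL_DEBUG'] = 'INFO'        # verbose NCCL logging
os.environ['NCCL_DEBUG_SUBSYS'] = 'ALL'  # optional: include all NCCL subsystems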

Hello, yesterday one of my machines started having the same problem. It cannot even create a single-node DDP model.

Here is the code. I made a small change so it runs with torchrun, and enabled verbose logging.

import torch
import torch.nn as nn
import torch.distributed as dist
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import datetime

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))

os.environ['NCCL_DEBUG'] = 'INFO'
os.environ['NCCL_DEBUG_SUBSYS'] = 'ALL'

# torchrun provides MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE, and LOCAL_RANK via the environment.
dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10))
local_rank = int(os.environ.get("LOCAL_RANK"))
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE"))
rank = int(os.environ.get("RANK"))
world_size = int(os.environ.get("WORLD_SIZE"))

device = torch.device("cuda", local_rank)
model = ToyModel().to(device)
print("Creating ddp model")
ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
print("Created ddp model")

Here is the output, @pritamdamania87:

torchrun --rdzv_backend=c10d --rdzv_id=456 --rdzv_endpoint=localhost:0 --nnodes=1 --node_rank=0 --nproc_per_node=2 testdist.py
master_addr is only used for static rdzv_backend and when rdzv_endpoint is not specified.
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
Creating ddp model
Creating ddp model
lyra:289864:289864 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to enp194s0f0
lyra:289864:289864 [0] NCCL INFO NCCL_SOCKET_IFNAME set to enp194s0f0
lyra:289864:289864 [0] NCCL INFO Bootstrap : Using enp194s0f0:192.168.1.177<0>
lyra:289864:289864 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
lyra:289864:289864 [0] NCCL INFO cudaDriverVersion 12010
NCCL version 2.14.3+cuda11.7
lyra:289864:289864 [0] NCCL INFO init.cc:1147 Cuda Host Alloc Size 4 pointer 0x7f09fb600000
lyra:289864:289883 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to enp194s0f0
lyra:289864:289883 [0] NCCL INFO NET/IB : No device found.
lyra:289864:289883 [0] NCCL INFO NCCL_SOCKET_IFNAME set by environment to enp194s0f0
lyra:289864:289883 [0] NCCL INFO NET/Socket : Using [0]enp194s0f0:192.168.1.177<0>
lyra:289864:289883 [0] NCCL INFO Using network Socket
lyra:289865:289865 [1] NCCL INFO cudaDriverVersion 12010
lyra:289865:289865 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to enp194s0f0
lyra:289865:289865 [1] NCCL INFO NCCL_SOCKET_IFNAME set to enp194s0f0
lyra:289865:289865 [1] NCCL INFO Bootstrap : Using enp194s0f0:192.168.1.177<0>
lyra:289865:289865 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
lyra:289865:289865 [1] NCCL INFO init.cc:1147 Cuda Host Alloc Size 4 pointer 0x7f77dc400000
lyra:289865:289904 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to enp194s0f0
lyra:289865:289904 [1] NCCL INFO NET/IB : No device found.
lyra:289865:289904 [1] NCCL INFO NCCL_SOCKET_IFNAME set by environment to enp194s0f0
lyra:289865:289904 [1] NCCL INFO NET/Socket : Using [0]enp194s0f0:192.168.1.177<0>
lyra:289865:289904 [1] NCCL INFO Using network Socket
lyra:289865:289904 [1] NCCL INFO NET/Socket : GPU Direct RDMA Disabled for HCA 0 'enp194s0f0'
lyra:289864:289883 [0] NCCL INFO NET/Socket : GPU Direct RDMA Disabled for HCA 0 'enp194s0f0'
lyra:289865:289904 [1] NCCL INFO transport/p2p.cc:151 Cuda Alloc Size 2097152 pointer 0x7f77dd000000
lyra:289864:289883 [0] NCCL INFO transport/p2p.cc:151 Cuda Alloc Size 2097152 pointer 0x7f09fc200000
lyra:289865:289904 [1] NCCL INFO === System : maxBw 24.0 totalBw 24.0 ===
lyra:289865:289904 [1] NCCL INFO CPU/0 (1/2/-1)
lyra:289865:289904 [1] NCCL INFO + SYS[5000.0] - CPU/1
lyra:289865:289904 [1] NCCL INFO + PCI[24.0] - GPU/1000 (0)
lyra:289865:289904 [1] NCCL INFO + PCI[24.0] - GPU/24000 (1)
lyra:289865:289904 [1] NCCL INFO CPU/1 (1/2/-1)
lyra:289865:289904 [1] NCCL INFO + SYS[5000.0] - CPU/0
lyra:289865:289904 [1] NCCL INFO + PCI[3.0] - NIC/C2000
lyra:289865:289904 [1] NCCL INFO ==========================================
lyra:289865:289904 [1] NCCL INFO GPU/1000 :GPU/1000 (0/5000.000000/LOC) GPU/24000 (2/24.000000/PHB) CPU/0 (1/24.000000/PHB) CPU/1 (2/24.000000/SYS)
lyra:289865:289904 [1] NCCL INFO GPU/24000 :GPU/1000 (2/24.000000/PHB) GPU/24000 (0/5000.000000/LOC) CPU/0 (1/24.000000/PHB) CPU/1 (2/24.000000/SYS)
lyra:289865:289904 [1] NCCL INFO Setting affinity for GPU 1 to ffffffff,ffffffff,00000000,00000000,ffffffff,ffffffff
lyra:289864:289883 [0] NCCL INFO === System : maxBw 24.0 totalBw 24.0 ===
lyra:289864:289883 [0] NCCL INFO CPU/0 (1/2/-1)
lyra:289864:289883 [0] NCCL INFO + SYS[5000.0] - CPU/1
lyra:289864:289883 [0] NCCL INFO + PCI[24.0] - GPU/1000 (0)
lyra:289864:289883 [0] NCCL INFO + PCI[24.0] - GPU/24000 (1)
lyra:289864:289883 [0] NCCL INFO CPU/1 (1/2/-1)
lyra:289864:289883 [0] NCCL INFO + SYS[5000.0] - CPU/0
lyra:289864:289883 [0] NCCL INFO + PCI[3.0] - NIC/C2000
lyra:289864:289883 [0] NCCL INFO ==========================================
lyra:289864:289883 [0] NCCL INFO GPU/1000 :GPU/1000 (0/5000.000000/LOC) GPU/24000 (2/24.000000/PHB) CPU/0 (1/24.000000/PHB) CPU/1 (2/24.000000/SYS)
lyra:289864:289883 [0] NCCL INFO GPU/24000 :GPU/1000 (2/24.000000/PHB) GPU/24000 (0/5000.000000/LOC) CPU/0 (1/24.000000/PHB) CPU/1 (2/24.000000/SYS)
lyra:289865:289904 [1] NCCL INFO Pattern 4, crossNic 0, nChannels 2, bw 12.000000/12.000000, type PHB/PIX, sameChannels 1
lyra:289865:289904 [1] NCCL INFO  0 : GPU/0 GPU/1
lyra:289864:289883 [0] NCCL INFO Setting affinity for GPU 0 to ffffffff,ffffffff,00000000,00000000,ffffffff,ffffffff
lyra:289865:289904 [1] NCCL INFO  1 : GPU/0 GPU/1
lyra:289865:289904 [1] NCCL INFO Pattern 1, crossNic 0, nChannels 2, bw 22.000000/22.000000, type PHB/PIX, sameChannels 0
lyra:289865:289904 [1] NCCL INFO  0 : GPU/0 GPU/1
lyra:289865:289904 [1] NCCL INFO  1 : GPU/1 GPU/0
lyra:289865:289904 [1] NCCL INFO Pattern 3, crossNic 0, nChannels 2, bw 22.000000/22.000000, type PHB/PIX, sameChannels 0
lyra:289865:289904 [1] NCCL INFO  0 : GPU/0 GPU/1
lyra:289865:289904 [1] NCCL INFO  1 : GPU/1 GPU/0
lyra:289864:289883 [0] NCCL INFO Pattern 4, crossNic 0, nChannels 2, bw 12.000000/12.000000, type PHB/PIX, sameChannels 1
lyra:289864:289883 [0] NCCL INFO  0 : GPU/0 GPU/1
lyra:289864:289883 [0] NCCL INFO  1 : GPU/0 GPU/1
lyra:289864:289883 [0] NCCL INFO Pattern 1, crossNic 0, nChannels 2, bw 22.000000/22.000000, type PHB/PIX, sameChannels 0
lyra:289864:289883 [0] NCCL INFO  0 : GPU/0 GPU/1
lyra:289864:289883 [0] NCCL INFO  1 : GPU/1 GPU/0
lyra:289864:289883 [0] NCCL INFO Pattern 3, crossNic 0, nChannels 2, bw 22.000000/22.000000, type PHB/PIX, sameChannels 0
lyra:289864:289883 [0] NCCL INFO  0 : GPU/0 GPU/1
lyra:289864:289883 [0] NCCL INFO  1 : GPU/1 GPU/0
lyra:289865:289904 [1] NCCL INFO Tree 0 : 0 -> 1 -> -1/-1/-1
lyra:289865:289904 [1] NCCL INFO Tree 2 : 0 -> 1 -> -1/-1/-1
lyra:289864:289883 [0] NCCL INFO Tree 0 : -1 -> 0 -> 1/-1/-1
lyra:289865:289904 [1] NCCL INFO Tree 1 : -1 -> 1 -> 0/-1/-1
lyra:289864:289883 [0] NCCL INFO Tree 2 : -1 -> 0 -> 1/-1/-1
lyra:289865:289904 [1] NCCL INFO Tree 3 : -1 -> 1 -> 0/-1/-1
lyra:289864:289883 [0] NCCL INFO Tree 1 : 1 -> 0 -> -1/-1/-1
lyra:289864:289883 [0] NCCL INFO Tree 3 : 1 -> 0 -> -1/-1/-1
lyra:289865:289904 [1] NCCL INFO Ring 00 : 0 -> 1 -> 0
lyra:289864:289883 [0] NCCL INFO Channel 00/04 :    0   1
lyra:289865:289904 [1] NCCL INFO Ring 01 : 0 -> 1 -> 0
lyra:289864:289883 [0] NCCL INFO Channel 01/04 :    0   1
lyra:289865:289904 [1] NCCL INFO Ring 02 : 0 -> 1 -> 0
lyra:289864:289883 [0] NCCL INFO Channel 02/04 :    0   1
lyra:289865:289904 [1] NCCL INFO Ring 03 : 0 -> 1 -> 0
lyra:289864:289883 [0] NCCL INFO Channel 03/04 :    0   1
lyra:289865:289904 [1] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] 0/-1/-1->1->-1 [2] -1/-1/-1->1->0 [3] 0/-1/-1->1->-1
lyra:289864:289883 [0] NCCL INFO Ring 00 : 1 -> 0 -> 1
lyra:289865:289904 [1] NCCL INFO misc/utils.cc:235 memory stack hunk malloc(65536)
lyra:289864:289883 [0] NCCL INFO Ring 01 : 1 -> 0 -> 1
lyra:289864:289883 [0] NCCL INFO Ring 02 : 1 -> 0 -> 1
lyra:289864:289883 [0] NCCL INFO Ring 03 : 1 -> 0 -> 1
lyra:289864:289883 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] -1/-1/-1->0->1 [2] 1/-1/-1->0->-1 [3] -1/-1/-1->0->1
lyra:289864:289883 [0] NCCL INFO misc/utils.cc:235 memory stack hunk malloc(65536)
lyra:289865:289904 [1] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f77dd000000
lyra:289865:289904 [1] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f77dd000600
lyra:289864:289883 [0] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f09fc200000
lyra:289865:289904 [1] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f77dd000800
lyra:289865:289904 [1] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f77dd000e00
lyra:289865:289904 [1] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f77dd001000
lyra:289865:289904 [1] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f77dd001600
lyra:289865:289904 [1] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f77dd001800
lyra:289865:289904 [1] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f77dd001e00
lyra:289864:289883 [0] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f09fc200600
lyra:289864:289883 [0] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f09fc200800
lyra:289864:289883 [0] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f09fc200e00
lyra:289864:289883 [0] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f09fc201000
lyra:289864:289883 [0] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f09fc201600
lyra:289864:289883 [0] NCCL INFO channel.cc:23 Cuda Alloc Size 1152 pointer 0x7f09fc201800
lyra:289864:289883 [0] NCCL INFO channel.cc:27 Cuda Alloc Size 8 pointer 0x7f09fc201e00
lyra:289865:289905 [1] NCCL INFO Mem Realloc old size 0, new size 8 pointer 0x7f77d0004c80
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004ca0
lyra:289865:289905 [1] NCCL INFO New proxy recv connection 0 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO Mem Realloc old size 0, new size 8 pointer 0x7f09bc000b70
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004d30
lyra:289864:289906 [0] NCCL INFO New proxy recv connection 0 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f77dd200000
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004ce0
lyra:289865:289905 [1] NCCL INFO New proxy recv connection 1 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f09fc400000
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004d70
lyra:289864:289906 [0] NCCL INFO New proxy recv connection 1 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f77d6000000
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004d20
lyra:289865:289905 [1] NCCL INFO New proxy recv connection 2 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f09fce00000
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004db0
lyra:289864:289906 [0] NCCL INFO New proxy recv connection 2 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f77d6a00000
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004d60
lyra:289865:289905 [1] NCCL INFO New proxy recv connection 3 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f09fd800000
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004df0
lyra:289864:289906 [0] NCCL INFO New proxy recv connection 3 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f77d7400000
lyra:289865:289904 [1] NCCL INFO Channel 00/0 : 1[24000] -> 0[1000] via P2P/IPC
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004da0
lyra:289865:289905 [1] NCCL INFO New proxy send connection 4 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:449 Cuda Alloc Size 10485760 pointer 0x7f09fe200000
lyra:289864:289883 [0] NCCL INFO Channel 00/0 : 0[1000] -> 1[24000] via P2P/IPC
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004e30
lyra:289864:289906 [0] NCCL INFO New proxy send connection 4 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f77d7e00000
lyra:289865:289904 [1] NCCL INFO Channel 01/0 : 1[24000] -> 0[1000] via P2P/IPC
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004de0
lyra:289865:289905 [1] NCCL INFO New proxy send connection 5 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f09fec00000
lyra:289864:289883 [0] NCCL INFO Channel 01/0 : 0[1000] -> 1[24000] via P2P/IPC
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004e70
lyra:289864:289906 [0] NCCL INFO New proxy send connection 5 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f77ddc00000
lyra:289865:289904 [1] NCCL INFO Channel 02/0 : 1[24000] -> 0[1000] via P2P/IPC
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004e20
lyra:289865:289905 [1] NCCL INFO New proxy send connection 6 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f09fee00000
lyra:289864:289883 [0] NCCL INFO Channel 02/0 : 0[1000] -> 1[24000] via P2P/IPC
lyra:289864:289906 [0] NCCL INFO New proxy send connection 6 from local rank 0, transport 0
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004eb0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f77dde00000
lyra:289865:289904 [1] NCCL INFO Channel 03/0 : 1[24000] -> 0[1000] via P2P/IPC
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004e60
lyra:289865:289905 [1] NCCL INFO New proxy send connection 7 from local rank 1, transport 0
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f09ff000000
lyra:289864:289883 [0] NCCL INFO Channel 03/0 : 0[1000] -> 1[24000] via P2P/IPC
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004ef0
lyra:289864:289906 [0] NCCL INFO New proxy send connection 7 from local rank 0, transport 0
lyra:289865:289905 [1] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f77ce000000
lyra:289864:289906 [0] NCCL INFO transport/p2p.cc:430 Cuda Alloc Size 2097152 pointer 0x7f09ff200000
lyra:289864:289883 [0] NCCL INFO Connected all rings
lyra:289864:289883 [0] NCCL INFO Connected all trees
lyra:289864:289883 [0] NCCL INFO Latency/AlgBw |    Tree/    LL |    Tree/ LL128 |    Tree/Simple |    Ring/    LL |    Ring/ LL128 |    Ring/Simple | CollNetDirect/    LL | CollNetDirect/ LL128 | CollNetDirect/Simple | CollNetChain/    LL | CollNetChain/ LL128 | CollNetChain/Simple |
lyra:289864:289883 [0] NCCL INFO  Max NThreads |            512 |            640 |            512 |            512 |            640 |            512 |              0 |              0 |            512 |              0 |              0 |            512 |
lyra:289864:289883 [0] NCCL INFO     Broadcast |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     4.6/   8.0 |    12.5/   0.0 |    14.1/  24.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |
lyra:289864:289883 [0] NCCL INFO        Reduce |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     4.6/   6.0 |    12.5/   0.0 |    14.1/  24.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |
lyra:289864:289883 [0] NCCL INFO     AllGather |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     4.6/  16.0 |    12.5/   0.0 |    14.1/  48.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |
lyra:289864:289883 [0] NCCL INFO ReduceScatter |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     4.6/  16.0 |    12.5/   0.0 |    14.1/  48.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |     0.0/   0.0 |
lyra:289864:289883 [0] NCCL INFO     AllReduce |     6.4/   5.3 |     8.2/   0.0 |    56.0/  20.2 |     5.6/   6.0 |    15.0/   0.0 |    19.8/  24.0 |     5.4/   0.0 |     5.4/   0.0 |    27.7/   0.0 |     4.4/   0.0 |     4.4/   0.0 |    16.0/   0.0 |
lyra:289864:289883 [0] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
lyra:289864:289883 [0] NCCL INFO 4 coll channels, 4 p2p channels, 2 p2p channels per peer
lyra:289864:289906 [0] NCCL INFO Allocated 4194656 bytes of shared memory in /dev/shm/nccl-qbUg3Y

lyra:289865:289904 [1] NCCL INFO Connected all rings
lyra:289865:289904 [1] NCCL INFO Connected all trees
lyra:289865:289904 [1] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 512 | 512
lyra:289865:289904 [1] NCCL INFO 4 coll channels, 4 p2p channels, 2 p2p channels per peer
lyra:289865:289905 [1] NCCL INFO Allocated 4194656 bytes of shared memory in /dev/shm/nccl-wS5jXh

lyra:289864:289906 [0] NCCL INFO New proxy send connection 8 from local rank 0, transport 2
lyra:289865:289905 [1] NCCL INFO New proxy send connection 8 from local rank 1, transport 2
lyra:289864:289883 [0] NCCL INFO Connection to proxy localRank 0 -> connection 0x7f09bc004f30
lyra:289865:289904 [1] NCCL INFO Connection to proxy localRank 1 -> connection 0x7f77d0004ea0
lyra:289865:289904 [1] NCCL INFO init.cc:367 Cuda Alloc Size 5168 pointer 0x7f77dd002000
lyra:289864:289883 [0] NCCL INFO init.cc:367 Cuda Alloc Size 5168 pointer 0x7f09fc202000
lyra:289864:289906 [0] NCCL INFO transport/net.cc:376 Cuda Alloc Size 8388608 pointer 0x7f0a02400000
lyra:289865:289905 [1] NCCL INFO transport/net.cc:376 Cuda Alloc Size 8388608 pointer 0x7f77cd200000
lyra:289864:289883 [0] NCCL INFO init.cc:392 Cuda Host Alloc Size 33554432 pointer 0x7f0a02c00000
lyra:289864:289883 [0] NCCL INFO init.cc:398 Cuda Host Alloc Size 128 pointer 0x7f09fb600200
lyra:289864:289883 [0] NCCL INFO comm 0x580c6ef0 rank 0 nranks 2 cudaDev 0 busId 1000 - Init COMPLETE
lyra:289864:289864 [0] NCCL INFO AllGather: opCount 0 sendbuff 0x7f09f7800800 recvbuff 0x7f09f7800e00 count 8 datatype 0 op 0 root 0 comm 0x580c6ef0 [nranks=2] stream 0x57ed16b0
lyra:289864:289864 [0] NCCL INFO misc/utils.cc:235 memory stack hunk malloc(65536)
lyra:289865:289904 [1] NCCL INFO init.cc:392 Cuda Host Alloc Size 33554432 pointer 0x7f77c6000000
lyra:289865:289904 [1] NCCL INFO init.cc:398 Cuda Host Alloc Size 128 pointer 0x7f77dc400200
lyra:289865:289904 [1] NCCL INFO comm 0x872216e0 rank 1 nranks 2 cudaDev 1 busId 24000 - Init COMPLETE
lyra:289865:289865 [1] NCCL INFO AllGather: opCount 0 sendbuff 0x7f7820400800 recvbuff 0x7f7820400e00 count 8 datatype 0 op 0 root 0 comm 0x872216e0 [nranks=2] stream 0x8700f7f0
lyra:289865:289865 [1] NCCL INFO misc/utils.cc:235 memory stack hunk malloc(65536)
[E ProcessGroupNCCL.cpp:828] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLGATHER, Timeout(ms)=10000) ran for 17323 milliseconds before timing out.
lyra:289865:289905 [1] NCCL INFO [Service thread] Connection closed by localRank 1
Traceback (most recent call last):
  File "/lyra-share/T0Projects2/n14.16-60only/testdist.py", line 31, in <module>
[E ProcessGroupNCCL.cpp:828] [Rank 0] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=1, OpType=ALLGATHER, Timeout(ms)=10000) ran for 17333 milliseconds before timing out.
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
RuntimeError: DDP expects same model across all ranks, but Rank 1 has 4 params, while rank 0 has inconsistent 0 params.
lyra:289865:289871 [0] NCCL INFO comm 0x872216e0 rank 1 nranks 2 cudaDev 1 busId 24000 - Abort COMPLETE
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
lyra:289864:289906 [0] NCCL INFO [Service thread] Connection closed by localRank 0
Traceback (most recent call last):
  File "/lyra-share/T0Projects2/n14.16-60only/testdist.py", line 31, in <module>
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 674, in __init__
lyra:289864:289874 [0] NCCL INFO comm 0x580c6ef0 rank 0 nranks 2 cudaDev 0 busId 1000 - Abort COMPLETE
[E ProcessGroupNCCL.cpp:455] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data.
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/utils.py", line 118, in _verify_param_shape_across_processes
[E ProcessGroupNCCL.cpp:460] To avoid data inconsistency, we are taking the entire process down.
    return dist._verify_params_across_processes(process_group, tensors, logger)
RuntimeError: DDP expects same model across all ranks, but Rank 0 has 4 params, while rank 1 has inconsistent 0 params.
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 0 (pid: 289864) of binary: /home/zj/anaconda3/bin/python
Traceback (most recent call last):
  File "/home/zj/anaconda3/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 794, in main
    run(args)
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/zj/anaconda3/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
=======================================================
testdist.py FAILED
-------------------------------------------------------
Failures:
[1]:
  time      : 2023-05-25_08:43:45
  host      : lyra
  rank      : 1 (local_rank: 1)
  exitcode  : -6 (pid: 289865)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 289865
-------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-25_08:43:45
  host      : lyra
  rank      : 0 (local_rank: 0)
  exitcode  : -6 (pid: 289864)
  error_file: <N/A>
  traceback : Signal 6 (SIGABRT) received by PID 289864
=======================================================

My observations:

  1. Rebooting does not help.
  2. The script worked until yesterday.
  3. The script works if I switch to the gloo backend.
  4. The script works on another machine of mine (nccl backend).

I suspect that some NCCL-related hardware has failed. Can anyone help me figure out where the problem is? Thanks.

@mrshenli You are right. dist.broadcast hangs too. Even the simplest dist.barrier() will hang.
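
For reference, the kind of minimal test I mean is roughly this sketch (same placeholder address/port as my original script; rank=0 on the $myip node):

import os
import torch
import torch.distributed as dist

os.environ['MASTER_ADDR'] = '$myip'
os.environ['MASTER_PORT'] = '7777'

dist.init_process_group(backend='nccl', world_size=2, rank=0)
torch.cuda.set_device(0)
dist.barrier()    # hangs here
print('barrier passed')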
