NCCL failing with A100 GPUs, works fine with V100 GPUs

Hello everyone! I tried solving this issue on my own, but after a few days I have to concede. Admittedly, I am no expert when it comes to Linux in general, and this is my first time working in a high-performance computing environment.

Although I was able to utilise DDP with NCCL in the past in order to train my models, I noticed a few days ago that I would get weird errors (something along the lines of size could not be broadcast) which I did not get when training my models a month ago. I did change PyTorch and Python versions since then, which is why I wanted to eliminate as many variables as possible and decided to work on a toy example taken from the DDP guide.

The following is the code I use:

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    print(f"Start running basic DDP example on rank {rank}.")

    # create model and move it to GPU with id rank
    device_id = rank % torch.cuda.device_count()
    model = ToyModel().to(device_id)
    print("Initialising DDP")
    ddp_model = DDP(model, device_ids=[device_id])
    print("Initialised DDP")

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(device_id)
    loss_fn(outputs, labels).backward()
    optimizer.step()
    dist.destroy_process_group()


if __name__ == "__main__":
    print(f"Running PyTorch {torch.__version__} with CUDA {torch.version.cuda} and NCCL {torch.cuda.nccl.version()}")
    demo_basic()

which I run using this command:

torchrun --nnodes 2 --nproc_per_node 1 --rdzv-backend c10d --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT path/to/script.py

The script works perfectly fine when I run it with NCCL on two nodes with one V100 GPU each, and also when I use the Gloo backend with A100 GPUs. However, I just cannot get it to work using NCCL on two nodes with one A100 each (or more; one per node is just for debugging purposes). This is the error log I am receiving:

Running PyTorch 2.2.2+cu121 with CUDA 12.1 and NCCL (2, 19, 3)
Start running basic DDP example on rank 1.
Initialising DDP
hpc-g4-1:52582:52582 [0] NCCL INFO cudaDriverVersion 12020
hpc-g4-1:52582:52582 [0] NCCL INFO Bootstrap : Using eth0:123.45.67.89<0>
hpc-g4-1:52582:52582 [0] NCCL INFO NET/Plugin : dlerror=libnccl-net.so: cannot open shared object file: No such file or directory No plugin found (libnccl-net.so), using internal implementation
hpc-g4-1:52582:52592 [0] NCCL INFO NET/IB : Using [0]mlx5_0:1/RoCE ; OOB eth0:123.45.67.89<0>
hpc-g4-1:52582:52592 [0] NCCL INFO Using non-device net plugin version 0
hpc-g4-1:52582:52592 [0] NCCL INFO Using network IB
hpc-g4-1:52582:52592 [0] NCCL INFO comm 0x8aef670 rank 1 nranks 2 cudaDev 0 nvmlDev 0 busId 60 commId 0xe200dbb906c36c4f - Init START
hpc-g4-1:52582:52592 [0] NCCL INFO Trees [0] -1/-1/-1->1->0 [1] 0/-1/-1->1->-1
hpc-g4-1:52582:52592 [0] NCCL INFO P2P Chunksize set to 131072
hpc-g4-1:52582:52592 [0] NCCL INFO Channel 00/0 : 0[0] -> 1[0] [receive] via NET/IB/0
hpc-g4-1:52582:52592 [0] NCCL INFO Channel 01/0 : 0[0] -> 1[0] [receive] via NET/IB/0
hpc-g4-1:52582:52592 [0] NCCL INFO Channel 00/0 : 1[0] -> 0[0] [send] via NET/IB/0
hpc-g4-1:52582:52592 [0] NCCL INFO Channel 01/0 : 1[0] -> 0[0] [send] via NET/IB/0

hpc-g4-1:52582:52594 [0] misc/ibvwrap.cc:190 NCCL WARN Call to ibv_create_cq failed with error Cannot allocate memory
hpc-g4-1:52582:52594 [0] NCCL INFO transport/net_ib.cc:520 -> 2
hpc-g4-1:52582:52594 [0] NCCL INFO transport/net_ib.cc:647 -> 2
hpc-g4-1:52582:52594 [0] NCCL INFO transport/net.cc:677 -> 2
hpc-g4-1:52582:52592 [0] NCCL INFO transport/net.cc:304 -> 2
hpc-g4-1:52582:52592 [0] NCCL INFO transport.cc:148 -> 2
hpc-g4-1:52582:52592 [0] NCCL INFO init.cc:1117 -> 2
hpc-g4-1:52582:52592 [0] NCCL INFO init.cc:1396 -> 2
hpc-g4-1:52582:52592 [0] NCCL INFO group.cc:64 -> 2 [Async thread]
hpc-g4-1:52582:52582 [0] NCCL INFO group.cc:418 -> 2
hpc-g4-1:52582:52582 [0] NCCL INFO group.cc:95 -> 2

hpc-g4-1:52582:52594 [0] misc/ibvwrap.cc:190 NCCL WARN Call to ibv_create_cq failed with error Cannot allocate memory
hpc-g4-1:52582:52594 [0] NCCL INFO transport/net_ib.cc:520 -> 2
hpc-g4-1:52582:52594 [0] NCCL INFO transport/net_ib.cc:647 -> 2
hpc-g4-1:52582:52594 [0] NCCL INFO transport/net.cc:677 -> 2
hpc-g4-1:52582:52594 [0] NCCL INFO misc/socket.cc:47 -> 3
hpc-g4-1:52582:52594 [0] NCCL INFO misc/socket.cc:58 -> 3
hpc-g4-1:52582:52594 [0] NCCL INFO misc/socket.cc:773 -> 3
hpc-g4-1:52582:52594 [0] NCCL INFO proxy.cc:1374 -> 3
hpc-g4-1:52582:52594 [0] NCCL INFO proxy.cc:1415 -> 3

hpc-g4-1:52582:52594 [0] proxy.cc:1557 NCCL WARN [Proxy Service 1] Failed to execute operation Connect from rank 1, retcode 3
Traceback (most recent call last):
  File "/users/felix.schoen/data/projects/Project/./meta/playground/ddp/ddp.py", line 45, in <module>
    demo_basic()
  File "/users/felix.schoen/data/projects/Project/./meta/playground/ddp/ddp.py", line 29, in demo_basic
    ddp_model = DDP(model, device_ids=[device_id])
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 798, in __init__
    _verify_param_shape_across_processes(self.process_group, parameters)
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/distributed/utils.py", line 263, in _verify_param_shape_across_processes
    return dist._verify_params_across_processes(process_group, tensors, logger)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1691, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.19.3
ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
Last error:
Call to ibv_create_cq failed with error Cannot allocate memory
hpc-g4-1:52582:52582 [0] NCCL INFO comm 0x8aef670 rank 1 nranks 2 cudaDev 0 busId 60 - Abort COMPLETE
[2024-04-23 00:11:00,951] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 52582) of binary: /users/felix.schoen/data/projects/Project/venv/bin/python3.11
Traceback (most recent call last):
  File "/users/felix.schoen/data/projects/Project/venv/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/users/felix.schoen/data/projects/Project/venv/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
./meta/playground/ddp/ddp.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-04-23_00:11:00
  host      : hpc-g4-1.domain.com
  rank      : 1 (local_rank: 0)
  exitcode  : 1 (pid: 52582)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================

ulimit -l is already set to unlimited. According to the official documentation NCCL 2.19.3 (the version PyTorch ships with) doesn’t even support CUDA 12.1, although I find that hard to believe. According to nvidia-smi the current driver version installed is 535.129.03 with support for CUDA 12.2.
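The `ulimit -l` value seen in an interactive shell is not necessarily the one the training processes inherit, for example when a scheduler or launcher starts them. A low locked-memory limit is one commonly cited cause of `ibv_create_cq` failing with "Cannot allocate memory", so it may be worth checking the limit from inside the Python process itself. The following is a minimal sketch using only the standard library; the function name `describe_memlock_limit` is made up for illustration.

```python
import resource


def describe_memlock_limit():
    """Return the locked-memory limit (RLIMIT_MEMLOCK) as seen by this process."""
    soft, hard = resource.getrlimit(resource.RLIMIT_MEMLOCK)

    def fmt(value):
        # RLIM_INFINITY is how the OS reports "unlimited".
        return "unlimited" if value == resource.RLIM_INFINITY else f"{value} bytes"

    return {
        "soft": fmt(soft),
        "hard": fmt(hard),
        "unlimited": soft == resource.RLIM_INFINITY,
    }


if __name__ == "__main__":
    info = describe_memlock_limit()
    print(f"RLIMIT_MEMLOCK soft={info['soft']} hard={info['hard']}")
```

Calling this at the top of `demo_basic()` on every rank would show whether the processes started by `torchrun` really run with an unlimited value.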

I’d be grateful for pointers on how to solve this, thanks in advance!

NCCL 2.19.3 not supporting CUDA 12.1 is indeed not the case: NCCL’s release notes list their binary builds and the corresponding CUDA runtime versions. We have been building PyTorch with NCCL==2.19.3+CUDA12.1 for a long time and haven’t seen any issues.

The failing ibv_create_cq call is indeed the issue, and you could try running the standalone NCCL tests to see whether the issue is reproducible there or is related to PyTorch. In the former case, you might want to create an issue in the NCCL repository (and CC me there so that I can also track it).
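To narrow down whether the failure is specific to the InfiniBand/RoCE transport, one option is to rerun the same job with more verbose NCCL logging and, as an experiment, with the IB transport disabled so that NCCL falls back to sockets. The sketch below only builds the environment for such a run. `NCCL_DEBUG`, `NCCL_DEBUG_SUBSYS`, `NCCL_IB_DISABLE` and `NCCL_SOCKET_IFNAME` are documented NCCL environment variables; the helper name `nccl_debug_env` and its parameters are hypothetical.

```python
import os


def nccl_debug_env(disable_ib=False, socket_ifname=None, base_env=None):
    """Build an environment dict for a diagnostic NCCL run.

    disable_ib=True asks NCCL not to use the IB/RoCE transport, which helps
    to tell a verbs-level problem apart from a general NCCL problem.
    """
    env = dict(os.environ if base_env is None else base_env)
    env["NCCL_DEBUG"] = "INFO"
    env["NCCL_DEBUG_SUBSYS"] = "INIT,NET"  # limit the log to init and network
    if disable_ib:
        env["NCCL_IB_DISABLE"] = "1"
    if socket_ifname is not None:
        # Assumption: the interface name, e.g. "eth0" as reported in the log above.
        env["NCCL_SOCKET_IFNAME"] = socket_ifname
    return env


if __name__ == "__main__":
    env = nccl_debug_env(disable_ib=True, socket_ifname="eth0")
    print({k: v for k, v in env.items() if k.startswith("NCCL_")})
```

The returned dict can be passed as `env=` to `subprocess.run` when launching `torchrun`. If the run succeeds with IB disabled, that points towards the verbs/RoCE setup on the A100 nodes rather than towards PyTorch; it is a diagnostic step, not a fix, since socket transport is much slower.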

Thanks for the quick reply! I ran the NCCL tests on the HPC system with one caveat: it currently only has NCCL modules up to version 2.18.3 with CUDA 12.1.1, which is what I used to run the tests.

I ran them on all four types of GPUs: P100s, V100s, Quadro RTX 6000s and A100s. Interestingly, the tests all seemed to pass. For example, here is the output for the setup with two nodes with one V100 GPU each:

# nThread 1 nGpus 1 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
# nThread 1 nGpus 1 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
#  Rank  0 Group  0 Pid  36665 on  clip-g2-3 device  0 [0x00] Tesla V100-PCIE-32GB
#  Rank  0 Group  0 Pid  43906 on  clip-g2-2 device  0 [0x00] Tesla V100-PCIE-32GB
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
           8             2     float     sum      -1     3.49    0.00    0.00      0     0.12    0.07    0.00      0
          16             4     float     sum      -1     4.73    0.00    0.00      0     0.12    0.13    0.00      0
          32             8     float     sum      -1     4.52    0.01    0.00      0     0.12    0.27    0.00      0
          64            16     float     sum      -1     4.68    0.01    0.00      0     0.12    0.54    0.00      0
         128            32     float     sum      -1     5.26    0.02    0.00      0     0.12    1.05    0.00      0
         256            64     float     sum      -1     4.43    0.06    0.00      0     0.12    2.14    0.00      0
         512           128     float     sum      -1     4.50    0.11    0.00      0     0.12    4.30    0.00      0
        1024           256     float     sum      -1     4.40    0.23    0.00      0     0.12    8.61    0.00      0
        2048           512     float     sum      -1     4.23    0.48    0.00      0     0.12   17.20    0.00      0
        4096          1024     float     sum      -1     4.33    0.94    0.00      0     0.12   34.57    0.00      0
        8192          2048     float     sum      -1     4.56    1.80    0.00      0     0.12   68.55    0.00      0
       16384          4096     float     sum      -1     4.51    3.63    0.00      0     0.12  136.59    0.00      0
       32768          8192     float     sum      -1     4.98    6.58    0.00      0     0.12  272.95    0.00      0
       65536         16384     float     sum      -1     5.02   13.06    0.00      0     0.12  557.28    0.00      0
      131072         32768     float     sum      -1     4.60   28.48    0.00      0     0.12  1093.18    0.00      0
      262144         65536     float     sum      -1     4.45   58.88    0.00      0     0.12  2253.06    0.00      0
      524288        131072     float     sum      -1     4.87  107.71    0.00      0     0.12  4490.69    0.00      0
     1048576        262144     float     sum      -1     6.02  174.28    0.00      0     0.12  8566.80    0.00      0
     2097152        524288     float     sum      -1     8.46  247.96    0.00      0     0.12  17331.83    0.00      0
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
     4194304       1048576     float     sum      -1    14.71  285.20    0.00      0     0.12  33866.00    0.00      0
           8             2     float     sum      -1     3.95    0.00    0.00      0     0.13    0.06    0.00      0
          16             4     float     sum      -1     5.45    0.00    0.00      0     0.13    0.13    0.00      0
          32             8     float     sum      -1     4.64    0.01    0.00      0     0.13    0.25    0.00      0
          64            16     float     sum      -1     4.35    0.01    0.00      0     0.13    0.51    0.00      0
         128            32     float     sum      -1     4.26    0.03    0.00      0     0.12    1.03    0.00      0
     8388608       2097152     float     sum      -1    25.69  326.50    0.00      0     0.12  67189.49    0.00      0
         256            64     float     sum      -1     4.26    0.06    0.00      0     0.12    2.06    0.00      0
         512           128     float     sum      -1     4.28    0.12    0.00      0     0.13    4.04    0.00      0
        1024           256     float     sum      -1     4.12    0.25    0.00      0     0.13    7.97    0.00      0
        2048           512     float     sum      -1     4.06    0.50    0.00      0     0.13   16.31    0.00      0
        4096          1024     float     sum      -1     4.27    0.96    0.00      0     0.13   32.44    0.00      0
        8192          2048     float     sum      -1     3.97    2.06    0.00      0     0.12   65.56    0.00      0
       16384          4096     float     sum      -1     4.18    3.92    0.00      0     0.12  131.44    0.00      0
       32768          8192     float     sum      -1     4.51    7.27    0.00      0     0.13  258.73    0.00      0
       65536         16384     float     sum      -1     4.35   15.06    0.00      0     0.13  517.46    0.00      0
    16777216       4194304     float     sum      -1    47.19  355.52    0.00      0     0.13  126334.46    0.00      0
      131072         32768     float     sum      -1     4.20   31.21    0.00      0     0.38  344.25    0.00      0
    33554432       8388608     float     sum      -1    90.14  372.26    0.00      0     0.13  261123.98    0.00      0
      262144         65536     float     sum      -1     4.24   61.88    0.00      0     0.13  2090.46    0.00      0
      524288        131072     float     sum      -1     4.55  115.17    0.00      0     0.12  4246.97    0.00      0
    67108864      16777216     float     sum      -1    176.1  381.12    0.00      0     0.13  530924.56    0.00      0
     1048576        262144     float     sum      -1     6.45  162.54    0.00      0     0.12  8425.68    0.00      0
   134217728      33554432     float     sum      -1    348.2  385.44    0.00      0     0.12  1080223.16    0.00      0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0 
#

     2097152        524288     float     sum      -1     9.41  222.83    0.00      0     0.13  16670.52    0.00      0
     4194304       1048576     float     sum      -1    14.64  286.41    0.00      0     0.13  33235.37    0.00      0
     8388608       2097152     float     sum      -1    26.00  322.62    0.00      0     0.13  65204.88    0.00      0
    16777216       4194304     float     sum      -1    47.17  355.68    0.00      0     0.13  133896.38    0.00      0
    33554432       8388608     float     sum      -1    90.09  372.46    0.00      0     0.13  265777.68    0.00      0
    67108864      16777216     float     sum      -1    176.0  381.28    0.00      0     0.13  531134.66    0.00      0
   134217728      33554432     float     sum      -1    348.0  385.63    0.00      0     0.13  1061849.11    0.00      0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0 
#

In contrast, this is the output of the first run on two nodes with one A100 each:

# nThread 1 nGpus 1 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
# nThread 1 nGpus 1 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
#  Rank  0 Group  0 Pid   9692 on clip-g4-11 device  0 [0xb1] NVIDIA A100-SXM4-40GB
#  Rank  0 Group  0 Pid  10235 on  clip-g4-9 device  0 [0x17] NVIDIA A100-SXM4-40GB
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
           8             2     float     sum      -1     6.62    0.00    0.00      0     0.27    0.03    0.00      0
          16             4     float     sum      -1     6.46    0.00    0.00      0     0.27    0.06    0.00      0
          32             8     float     sum      -1     6.46    0.00    0.00      0     0.28    0.11    0.00      0
          64            16     float     sum      -1     6.39    0.01    0.00      0     0.27    0.24    0.00      0
         128            32     float     sum      -1     6.40    0.02    0.00      0     0.27    0.48    0.00      0
         256            64     float     sum      -1     6.41    0.04    0.00      0     0.27    0.96    0.00      0
         512           128     float     sum      -1     6.41    0.08    0.00      0     0.27    1.92    0.00      0
           8             2     float     sum      -1     6.06    0.00    0.00      0     0.24    0.03    0.00      0
        1024           256     float     sum      -1     6.43    0.16    0.00      0     0.20    5.24    0.00      0
          16             4     float     sum      -1     5.95    0.00    0.00      0     0.25    0.06    0.00      0
        2048           512     float     sum      -1     4.88    0.42    0.00      0     0.20   10.36    0.00      0
          32             8     float     sum      -1     5.92    0.01    0.00      0     0.25    0.13    0.00      0
        4096          1024     float     sum      -1     5.02    0.82    0.00      0     0.19   21.21    0.00      0
          64            16     float     sum      -1     5.09    0.01    0.00      0     0.18    0.36    0.00      0
        8192          2048     float     sum      -1     5.18    1.58    0.00      0     0.20   41.82    0.00      0
         128            32     float     sum      -1     4.44    0.03    0.00      0     0.18    0.73    0.00      0
         256            64     float     sum      -1     4.57    0.06    0.00      0     0.17    1.47    0.00      0
       16384          4096     float     sum      -1     4.91    3.34    0.00      0     0.22   73.95    0.00      0
         512           128     float     sum      -1     4.35    0.12    0.00      0     0.17    3.01    0.00      0
        1024           256     float     sum      -1     4.44    0.23    0.00      0     0.17    5.98    0.00      0
        2048           512     float     sum      -1     4.42    0.46    0.00      0     0.17   11.91    0.00      0
       32768          8192     float     sum      -1     5.01    6.54    0.00      0     0.22  148.41    0.00      0
        4096          1024     float     sum      -1     4.44    0.92    0.00      0     0.17   24.04    0.00      0
        8192          2048     float     sum      -1     4.43    1.85    0.00      0     0.18   46.37    0.00      0
       16384          4096     float     sum      -1     4.41    3.72    0.00      0     0.20   82.04    0.00      0
       32768          8192     float     sum      -1     4.47    7.33    0.00      0     0.13  247.59    0.00      0
       65536         16384     float     sum      -1     5.02   13.06    0.00      0     0.15  449.03    0.00      0
       65536         16384     float     sum      -1     4.05   16.17    0.00      0     0.14  482.41    0.00      0
      131072         32768     float     sum      -1     4.23   30.99    0.00      0     0.11  1147.74    0.00      0
      131072         32768     float     sum      -1     4.13   31.70    0.00      0     0.12  1094.09    0.00      0
      262144         65536     float     sum      -1     4.32   60.70    0.00      0     0.12  2205.67    0.00      0
      262144         65536     float     sum      -1     4.62   56.79    0.00      0     0.12  2270.63    0.00      0
      524288        131072     float     sum      -1     4.57  114.66    0.00      0     0.12  4484.93    0.00      0
      524288        131072     float     sum      -1     4.63  113.23    0.00      0     0.12  4508.07    0.00      0
     1048576        262144     float     sum      -1     6.29  166.79    0.00      0     0.12  9031.66    0.00      0
     1048576        262144     float     sum      -1     8.48  123.64    0.00      0     0.13  7983.07    0.00      0
     2097152        524288     float     sum      -1     8.09  259.22    0.00      0     0.12  17571.45    0.00      0
     2097152        524288     float     sum      -1     8.27  253.71    0.00      0     0.12  18055.55    0.00      0
     4194304       1048576     float     sum      -1    11.92  352.00    0.00      0     0.12  34620.75    0.00      0
     4194304       1048576     float     sum      -1    11.87  353.31    0.00      0     0.12  36157.79    0.00      0
     8388608       2097152     float     sum      -1    16.63  504.48    0.00      0     0.12  70879.66    0.00      0
     8388608       2097152     float     sum      -1    18.02  465.54    0.00      0     0.11  74698.20    0.00      0
    16777216       4194304     float     sum      -1    29.07  577.15    0.00      0     0.11  146079.37    0.00      0
    16777216       4194304     float     sum      -1    30.31  553.53    0.00      0     0.11  148274.11    0.00      0
    33554432       8388608     float     sum      -1    54.78  612.53    0.00      0     0.12  289262.34    0.00      0
    33554432       8388608     float     sum      -1    54.70  613.38    0.00      0     0.11  297864.47    0.00      0
    67108864      16777216     float     sum      -1    105.6  635.68    0.00      0     0.12  567037.30    0.00      0
    67108864      16777216     float     sum      -1    105.3  637.47    0.00      0     0.11  596788.47    0.00      0
   134217728      33554432     float     sum      -1    206.4  650.41    0.00      0     0.11  1167110.68    0.00      0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0 
#

   134217728      33554432     float     sum      -1    206.6  649.63    0.00      0     0.12  1148632.67    0.00      0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0 
#

Here, the formatting confused me because the first “out-of-place” block is empty. When I ran the test again, this was no longer the case:

# nThread 1 nGpus 1 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
# nThread 1 nGpus 1 minBytes 8 maxBytes 134217728 step: 2(factor) warmup iters: 5 iters: 20 agg iters: 1 validation: 1 graph: 0
#
# Using devices
#  Rank  0 Group  0 Pid  42962 on  clip-g4-3 device  0 [0x31] NVIDIA A100-SXM4-40GB
#  Rank  0 Group  0 Pid  17276 on  clip-g4-4 device  0 [0xca] NVIDIA A100-SXM4-40GB
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
           8             2     float     sum      -1     5.61    0.00    0.00      0     0.17    0.05    0.00      0
          16             4     float     sum      -1     4.51    0.00    0.00      0     0.18    0.09    0.00      0
          32             8     float     sum      -1     4.50    0.01    0.00      0     0.17    0.18    0.00      0
          64            16     float     sum      -1     4.48    0.01    0.00      0     0.17    0.38    0.00      0
         128            32     float     sum      -1     4.50    0.03    0.00      0     0.17    0.75    0.00      0
         256            64     float     sum      -1     4.64    0.06    0.00      0     0.17    1.49    0.00      0
         512           128     float     sum      -1     4.45    0.12    0.00      0     0.17    3.02    0.00      0
        1024           256     float     sum      -1     4.65    0.22    0.00      0     0.17    6.08    0.00      0
        2048           512     float     sum      -1     4.44    0.46    0.00      0     0.17   12.08    0.00      0
        4096          1024     float     sum      -1     4.47    0.92    0.00      0     0.11   35.76    0.00      0
        8192          2048     float     sum      -1     3.76    2.18    0.00      0     0.12   70.08    0.00      0
       16384          4096     float     sum      -1     3.88    4.22    0.00      0     0.13  121.41    0.00      0
       32768          8192     float     sum      -1     3.95    8.29    0.00      0     0.13  247.87    0.00      0
       65536         16384     float     sum      -1     3.92   16.72    0.00      0     0.12  555.63    0.00      0
      131072         32768     float     sum      -1     4.01   32.66    0.00      0     0.12  1128.47    0.00      0
      262144         65536     float     sum      -1     4.25   61.75    0.00      0     0.11  2302.54    0.00      0
      524288        131072     float     sum      -1     4.49  116.84    0.00      0     0.12  4502.26    0.00      0
#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
     1048576        262144     float     sum      -1     5.78  181.49    0.00      0     0.11  9185.95    0.00      0
           8             2     float     sum      -1     7.97    0.00    0.00      0     0.34    0.02    0.00      0
          16             4     float     sum      -1     9.78    0.00    0.00      0     0.33    0.05    0.00      0
          32             8     float     sum      -1     8.95    0.00    0.00      0     0.32    0.10    0.00      0
          64            16     float     sum      -1     9.39    0.01    0.00      0     0.32    0.20    0.00      0
     2097152        524288     float     sum      -1     7.63  274.92    0.00      0     0.11  18509.73    0.00      0
         128            32     float     sum      -1     8.77    0.01    0.00      0     0.33    0.39    0.00      0
         256            64     float     sum      -1     8.76    0.03    0.00      0     0.33    0.78    0.00      0
         512           128     float     sum      -1     8.39    0.06    0.00      0     0.33    1.56    0.00      0
        1024           256     float     sum      -1     7.11    0.14    0.00      0     0.21    4.95    0.00      0
        2048           512     float     sum      -1     6.10    0.34    0.00      0     0.21    9.73    0.00      0
        4096          1024     float     sum      -1     6.08    0.67    0.00      0     0.21   19.60    0.00      0
        8192          2048     float     sum      -1     6.03    1.36    0.00      0     0.21   39.55    0.00      0
       16384          4096     float     sum      -1     6.22    2.63    0.00      0     0.24   68.80    0.00      0
       32768          8192     float     sum      -1     6.22    5.27    0.00      0     0.25  133.28    0.00      0
       65536         16384     float     sum      -1     6.10   10.75    0.00      0     0.17  376.00    0.00      0
     4194304       1048576     float     sum      -1    11.47  365.79    0.00      0     0.11  36808.28    0.00      0
      131072         32768     float     sum      -1     4.90   26.75    0.00      0     0.12  1121.71    0.00      0
     8388608       2097152     float     sum      -1    17.45  480.85    0.00      0     0.11  73941.01    0.00      0
      262144         65536     float     sum      -1     4.05   64.74    0.00      0     0.11  2285.48    0.00      0
    16777216       4194304     float     sum      -1    29.65  565.77    0.00      0     0.11  149130.81    0.00      0
      524288        131072     float     sum      -1     4.40  119.21    0.00      0     0.12  4462.03    0.00      0
    33554432       8388608     float     sum      -1    55.45  605.08    0.00      0     0.11  293436.22    0.00      0
     1048576        262144     float     sum      -1     6.02  174.06    0.00      0     0.12  8996.79    0.00      0
    67108864      16777216     float     sum      -1    106.3  631.37    0.00      0     0.11  589449.84    0.00      0
     2097152        524288     float     sum      -1     7.82  268.08    0.00      0     0.12  17848.10    0.00      0
     4194304       1048576     float     sum      -1    11.73  357.69    0.00      0     0.11  36647.48    0.00      0
   134217728      33554432     float     sum      -1    207.1  648.05    0.00      0     0.11  1183056.22    0.00      0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0 
#

     8388608       2097152     float     sum      -1    16.62  504.81    0.00      0     0.11  74400.07    0.00      0
    16777216       4194304     float     sum      -1    28.92  580.14    0.00      0     0.11  146206.68    0.00      0
    33554432       8388608     float     sum      -1    54.77  612.61    0.00      0     0.12  284600.78    0.00      0
    67108864      16777216     float     sum      -1    105.3  637.11    0.00      0     0.11  592572.75    0.00      0
   134217728      33554432     float     sum      -1    206.2  650.83    0.00      0     0.11  1167618.34    0.00      0
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 0 
#

To me it seems like NCCL is working just fine but that there could be a configuration problem with the HPC system. Do you have any further pointers on how to fix these issues? I plan on contacting the administrators of the system in order to get this resolved and I’d like to be able to provide them with as much information as possible!
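One thing worth checking before drawing that conclusion: in the outputs above, both processes report themselves as `Rank  0`, the header is printed twice, and `busbw` is 0.00 throughout. As far as I understand nccl-tests, that pattern is what you get when each node runs its own independent single-GPU test instead of one job spanning both nodes (typically launched through MPI), in which case no inter-node traffic is exercised at all. The following sketch parses the `# Rank ...` header lines and flags that situation; the function names are made up for illustration and the regular expression is based only on the header format shown above.

```python
import re

# Matches header lines such as:
# "#  Rank  0 Group  0 Pid  36665 on  clip-g2-3 device  0 [0x00] Tesla V100-PCIE-32GB"
RANK_LINE = re.compile(
    r"^#\s+Rank\s+(\d+)\s+Group\s+(\d+)\s+Pid\s+(\d+)\s+on\s+(\S+)\s+device\s+(\d+)"
)


def parse_rank_lines(output):
    """Extract rank, pid, host and device from nccl-tests header lines."""
    ranks = []
    for line in output.splitlines():
        match = RANK_LINE.match(line.strip())
        if match:
            rank, _group, pid, host, device = match.groups()
            ranks.append(
                {"rank": int(rank), "pid": int(pid), "host": host, "device": int(device)}
            )
    return ranks


def looks_like_single_job(output):
    """True if the ranks found are exactly 0..N-1, as expected for one joint run."""
    ranks = sorted(entry["rank"] for entry in parse_rank_lines(output))
    return len(ranks) > 0 and ranks == list(range(len(ranks)))


if __name__ == "__main__":
    sample = """
#  Rank  0 Group  0 Pid  36665 on  clip-g2-3 device  0 [0x00] Tesla V100-PCIE-32GB
#  Rank  0 Group  0 Pid  43906 on  clip-g2-2 device  0 [0x00] Tesla V100-PCIE-32GB
"""
    print(looks_like_single_job(sample))
```

If this returns `False` for a log, the test most likely needs to be rerun as a single multi-node job before its result says anything about the InfiniBand/RoCE path that fails in the PyTorch run.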

On an unrelated note: I really appreciate all your thousands of answers, they were really helpful in the past!

I am facing the same issue with NCCL version 2.21.5+cuda12.4. Is there any solution?

#
#                                                              out-of-place                       in-place          
#       size         count      type   redop    root     time   algbw   busbw #wrong     time   algbw   busbw #wrong
#        (B)    (elements)                               (us)  (GB/s)  (GB/s)            (us)  (GB/s)  (GB/s)       
           8             2     float     sum      -1    17.66    0.00    0.00      0    17.01    0.00    0.00      0
          16             4     float     sum      -1    17.77    0.00    0.00      0    17.55    0.00    0.00      0
          32             8     float     sum      -1    17.48    0.00    0.00      0    17.31    0.00    0.00      0
          64            16     float     sum      -1    17.42    0.00    0.01      0    17.25    0.00    0.01      0
         128            32     float     sum      -1    17.30    0.01    0.01      0    17.70    0.01    0.01      0
         256            64     float     sum      -1    17.53    0.01    0.02      0    17.54    0.01    0.02      0
         512           128     float     sum      -1    18.38    0.03    0.04      0    18.33    0.03    0.04      0
        1024           256     float     sum      -1    19.73    0.05    0.08      0    19.16    0.05    0.08      0
        2048           512     float     sum      -1    20.54    0.10    0.15      0    20.23    0.10    0.15      0
        4096          1024     float     sum      -1    20.20    0.20    0.30      0    37.26    0.11    0.16      0
        8192          2048     float     sum      -1    22.71    0.36    0.54      0    22.38    0.37    0.55      0
       16384          4096     float     sum      -1    24.40    0.67    1.01      0    23.54    0.70    1.04      0
       32768          8192     float     sum      -1    27.88    1.18    1.76      0    27.94    1.17    1.76      0
       65536         16384     float     sum      -1    28.98    2.26    3.39      0    29.89    2.19    3.29      0
      131072         32768     float     sum      -1    40.57    3.23    4.85      0    38.08    3.44    5.16      0
      262144         65536     float     sum      -1    53.12    4.93    7.40      0    53.78    4.87    7.31      0
      524288        131072     float     sum      -1    83.01    6.32    9.47      0    82.24    6.37    9.56      0
     1048576        262144     float     sum      -1    56.36   18.61   27.91      0    55.67   18.83   28.25      0
     2097152        524288     float     sum      -1    70.52   29.74   44.61      0    69.63   30.12   45.18      0
     4194304       1048576     float     sum      -1    97.88   42.85   64.28      0    97.40   43.06   64.59      0
     8388608       2097152     float     sum      -1    168.6   49.75   74.63      0    168.0   49.94   74.92      0
    16777216       4194304     float     sum      -1    313.2   53.57   80.35      0    313.1   53.58   80.38      0
    33554432       8388608     float     sum      -1    586.5   57.21   85.82      0    586.8   57.18   85.77      0
    67108864      16777216     float     sum      -1   1099.2   61.05   91.58      0   1099.2   61.05   91.58      0
   134217728      33554432     float     sum      -1   2258.4   59.43   89.15      0   2249.0   59.68   89.52      0
   268435456      67108864     float     sum      -1   4153.1   64.63   96.95      0   4155.0   64.61   96.91      0
   536870912     134217728     float     sum      -1   5995.8   89.54  134.31      0   5989.7   89.63  134.45      0
  1073741824     268435456     float     sum      -1    11861   90.53  135.79      0    11842   90.68  136.01      0
  2147483648     536870912     float     sum      -1    23548   91.20  136.79      0    23518   91.31  136.97      0
  4294967296    1073741824     float     sum      -1    47009   91.37  137.05      0    46983   91.42  137.12      0
  8589934592    2147483648     float     sum      -1    93890   91.49  137.23      0    93634   91.74  137.61      0
farm22-gpu0304:22511:22511 [1] NCCL INFO comm 0x55a226d7b990 rank 3 nranks 4 cudaDev 1 busId 3b000 - Destroy COMPLETE
farm22-gpu0303:21358:21358 [1] NCCL INFO comm 0x55c386f0d190 rank 1 nranks 4 cudaDev 1 busId 3b000 - Destroy COMPLETE
farm22-gpu0304:22509:22509 [0] NCCL INFO comm 0x5629410833b0 rank 2 nranks 4 cudaDev 0 busId 19000 - Destroy COMPLETE
farm22-gpu0303:21357:21357 [0] NCCL INFO comm 0x563416ecc320 rank 0 nranks 4 cudaDev 0 busId 19000 - Destroy COMPLETE
# Out of bounds values : 0 OK
# Avg bus bandwidth    : 44.0951 
#

I don’t know what your error is since you didn’t post one, but you could check this post, which explains how to increase the pinned memory limit in case it’s too low.
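
As a quick way to see what the limit currently is, the sketch below reads the locked-memory limit of the current process with Python's standard `resource` module (Unix only). I am assuming that "pinned memory limit" corresponds to `RLIMIT_MEMLOCK`, the value shown by `ulimit -l`; the helper names are made up for this example.

```python
import resource


def format_limit(value):
    """Render an rlimit value, showing RLIM_INFINITY as 'unlimited'."""
    return "unlimited" if value == resource.RLIM_INFINITY else f"{value} bytes"


def memlock_limits():
    """Return the (soft, hard) locked-memory limits of this process."""
    # Assumption: RLIMIT_MEMLOCK is the limit relevant to pinned memory.
    return resource.getrlimit(resource.RLIMIT_MEMLOCK)


if __name__ == "__main__":
    soft, hard = memlock_limits()
    print(f"RLIMIT_MEMLOCK soft={format_limit(soft)} hard={format_limit(hard)}")
```

It is worth running this inside the actual batch job on each node, since the limits seen there can differ from those in a login shell.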

Thank you for the prompt reply. Here is the error output:

[rank4]:[I1219 14:16:17.711463201 ProcessGroupWrapper.cpp:587] [Rank 4] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank5]:[I1219 14:16:17.711465589 ProcessGroupWrapper.cpp:587] [Rank 5] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank7]:[I1219 14:16:17.721750357 ProcessGroupWrapper.cpp:587] [Rank 7] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank6]:[I1219 14:16:17.721762844 ProcessGroupWrapper.cpp:587] [Rank 6] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank3]:[I1219 14:16:17.765543226 ProcessGroupWrapper.cpp:587] [Rank 3] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank2]:[I1219 14:16:17.765543174 ProcessGroupWrapper.cpp:587] [Rank 2] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank0]:[I1219 14:16:17.765542362 ProcessGroupWrapper.cpp:587] [Rank 0] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank1]:[I1219 14:16:17.765540404 ProcessGroupWrapper.cpp:587] [Rank 1] Running collective: CollectiveFingerPrint(SequenceNumber=0, OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt)))
[rank0]:[I1219 14:16:17.836304813 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 0] ProcessGroupNCCL broadcast unique ID through store took 0.039207 ms
[rank3]:[I1219 14:16:17.836386623 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 3] ProcessGroupNCCL broadcast unique ID through store took 56.5557 ms
[rank1]:[I1219 14:16:17.836409304 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL broadcast unique ID through store took 56.577 ms
[rank2]:[I1219 14:16:17.836435087 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 2] ProcessGroupNCCL broadcast unique ID through store took 56.5792 ms
[rank4]:[I1219 14:16:17.880634490 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 4] ProcessGroupNCCL broadcast unique ID through store took 57.0495 ms
[rank7]:[I1219 14:16:17.880656896 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 7] ProcessGroupNCCL broadcast unique ID through store took 57.0765 ms
[rank6]:[I1219 14:16:17.880662017 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 6] ProcessGroupNCCL broadcast unique ID through store took 57.0812 ms
[rank5]:[I1219 14:16:17.880673318 ProcessGroupNCCL.cpp:2262] [PG ID 0 PG GUID 0 Rank 5] ProcessGroupNCCL broadcast unique ID through store took 57.0881 ms
[rank1]: Traceback (most recent call last):
[rank1]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1017, in <module>
[rank1]:     sys.exit(recipe_main())
[rank1]:              ^^^^^^^^^^^^^
[rank1]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank1]:     sys.exit(recipe_main(conf))
[rank1]:              ^^^^^^^^^^^^^^^^^
[rank1]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1010, in recipe_main
[rank1]:     recipe = FullFinetuneRecipeDistributed(cfg=cfg)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 225, in __init__
[rank1]:     self.seed = training.set_seed(seed=cfg.seed)
[rank1]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/seed.py", line 53, in set_seed
[rank1]:     seed = _broadcast_tensor(rand_seed, 0).item()  # sync seed across ranks
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/_distributed.py", line 107, in _broadcast_tensor
[rank1]:     dist.broadcast(tensor, src=src, group=None)
[rank1]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2417, in broadcast
[rank1]:     work = default_pg.broadcast([tensor], opts)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/NCCLUtils.hpp:317, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
[rank1]: ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
[rank1]: Last error:
[rank1]: [Proxy Service 1] Failed to execute operation Connect from rank 1, retcode 3
[rank2]: Traceback (most recent call last):
[rank2]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1017, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:              ^^^^^^^^^^^^^
[rank2]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:              ^^^^^^^^^^^^^^^^^
[rank2]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1010, in recipe_main
[rank2]:     recipe = FullFinetuneRecipeDistributed(cfg=cfg)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 225, in __init__
[rank2]:     self.seed = training.set_seed(seed=cfg.seed)
[rank2]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/seed.py", line 53, in set_seed
[rank2]:     seed = _broadcast_tensor(rand_seed, 0).item()  # sync seed across ranks
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/_distributed.py", line 107, in _broadcast_tensor
[rank2]:     dist.broadcast(tensor, src=src, group=None)
[rank2]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2417, in broadcast
[rank2]:     work = default_pg.broadcast([tensor], opts)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/NCCLUtils.hpp:317, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
[rank2]: ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
[rank2]: Last error:
[rank2]: [Proxy Service 2] Failed to execute operation Connect from rank 2, retcode 3
[rank5]: Traceback (most recent call last):
[rank5]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1017, in <module>
[rank5]:     sys.exit(recipe_main())
[rank5]:              ^^^^^^^^^^^^^
[rank5]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank5]:     sys.exit(recipe_main(conf))
[rank5]:              ^^^^^^^^^^^^^^^^^
[rank5]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1010, in recipe_main
[rank5]:     recipe = FullFinetuneRecipeDistributed(cfg=cfg)
[rank5]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 225, in __init__
[rank5]:     self.seed = training.set_seed(seed=cfg.seed)
[rank5]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/seed.py", line 53, in set_seed
[rank5]:     seed = _broadcast_tensor(rand_seed, 0).item()  # sync seed across ranks
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/_distributed.py", line 107, in _broadcast_tensor
[rank5]:     dist.broadcast(tensor, src=src, group=None)
[rank5]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank5]:     return func(*args, **kwargs)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^
[rank5]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2417, in broadcast
[rank5]:     work = default_pg.broadcast([tensor], opts)
[rank5]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank5]: torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/NCCLUtils.hpp:317, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
[rank5]: ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
[rank5]: Last error:
[rank5]: [Proxy Service 5] Failed to execute operation Connect from rank 5, retcode 3
[rank6]: Traceback (most recent call last):
[rank6]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1017, in <module>
[rank6]:     sys.exit(recipe_main())
[rank6]:              ^^^^^^^^^^^^^
[rank6]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank6]:     sys.exit(recipe_main(conf))
[rank6]:              ^^^^^^^^^^^^^^^^^
[rank6]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 1010, in recipe_main
[rank6]:     recipe = FullFinetuneRecipeDistributed(cfg=cfg)
[rank6]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/nfs/users/nfs_f/fg12/repos/mlflow-tutorial/torch_tune/src/full_test_dist.py", line 225, in __init__
[rank6]:     self.seed = training.set_seed(seed=cfg.seed)
[rank6]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/seed.py", line 53, in set_seed
[rank6]:     seed = _broadcast_tensor(rand_seed, 0).item()  # sync seed across ranks
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torchtune/training/_distributed.py", line 107, in _broadcast_tensor
[rank6]:     dist.broadcast(tensor, src=src, group=None)
[rank6]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank6]:     return func(*args, **kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/software/isg/users/fg12/envs/virtualenvs/mlflow-torchtune-x8ZHjNqV-py3.11/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2417, in broadcast
[rank6]:     work = default_pg.broadcast([tensor], opts)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/NCCLUtils.hpp:317, unhandled system error (run with NCCL_DEBUG=INFO for details), NCCL version 2.21.5
[rank6]: ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. 
[rank6]: Last error:
[rank6]: [Proxy Service 6] Failed to execute operation Connect from rank 6, retcode 3
[rank1]:[I1219 14:16:20.322903302 ProcessGroupNCCL.cpp:1246] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL destructor entered.
[rank1]:[I1219 14:16:20.322925171 ProcessGroupNCCL.cpp:1230] [PG ID 0 PG GUID 0 Rank 1] Launching ProcessGroupNCCL abort asynchrounously.
[rank2]:[I1219 14:16:20.322906064 ProcessGroupNCCL.cpp:1246] [PG ID 0 PG GUID 0 Rank 2] ProcessGroupNCCL destructor entered.
[rank2]:[I1219 14:16:20.322927464 ProcessGroupNCCL.cpp:1230] [PG ID 0 PG GUID 0 Rank 2] Launching ProcessGroupNCCL abort asynchrounously.
[rank2]:[I1219 14:16:20.323750881 ProcessGroupNCCL.cpp:1116] [PG ID 0 PG GUID 0 Rank 2] future is successfully executed for: ProcessGroup abort
[rank2]:[I1219 14:16:20.323759601 ProcessGroupNCCL.cpp:1237] [PG ID 0 PG GUID 0 Rank 2] ProcessGroupNCCL aborts successfully.
[rank1]:[I1219 14:16:20.323757341 ProcessGroupNCCL.cpp:1116] [PG ID 0 PG GUID 0 Rank 1] future is successfully executed for: ProcessGroup abort
[rank1]:[I1219 14:16:20.323760876 ProcessGroupNCCL.cpp:1237] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL aborts successfully.
[rank2]:[I1219 14:16:20.323766013 ProcessGroupNCCL.cpp:1267] [PG ID 0 PG GUID 0 Rank 2] ProcessGroupNCCL watchdog thread joined.
[rank1]:[I1219 14:16:20.323766232 ProcessGroupNCCL.cpp:1267] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL watchdog thread joined.
[rank2]:[I1219 14:16:20.323780605 ProcessGroupNCCL.cpp:1271] [PG ID 0 PG GUID 0 Rank 2] ProcessGroupNCCL heart beat monitor thread joined.
[rank1]:[I1219 14:16:20.323787105 ProcessGroupNCCL.cpp:1271] [PG ID 0 PG GUID 0 Rank 1] ProcessGroupNCCL heart beat monitor thread joined.
[rank0]:[I1219 14:16:20.323810413 TCPStoreLibUvBackend.cpp:119] [c10d - debug] Read callback failed. code:-4095 name:EOF desc:end of file
[rank0]:[I1219 14:16:20.323859105 TCPStoreLibUvBackend.cpp:119] [c10d - debug] Read callback failed. code:-4095 name:EOF desc:end of file
[rank5]:[I1219 14:16:20.387344951 ProcessGroupNCCL.cpp:1246] [PG ID 0 PG GUID 0 Rank 5] ProcessGroupNCCL destructor entered.
[rank5]:[I1219 14:16:20.387362982 ProcessGroupNCCL.cpp:1230] [PG ID 0 PG GUID 0 Rank 5] Launching ProcessGroupNCCL abort asynchrounously.
[rank6]:[I1219 14:16:20.387412923 ProcessGroupNCCL.cpp:1246] [PG ID 0 PG GUID 0 Rank 6] ProcessGroupNCCL destructor entered.
[rank6]:[I1219 14:16:20.387431829 ProcessGroupNCCL.cpp:1230] [PG ID 0 PG GUID 0 Rank 6] Launching ProcessGroupNCCL abort asynchrounously.
[rank5]:[I1219 14:16:20.387946437 ProcessGroupNCCL.cpp:1116] [PG ID 0 PG GUID 0 Rank 5] future is successfully executed for: ProcessGroup abort
[rank5]:[I1219 14:16:20.387951199 ProcessGroupNCCL.cpp:1237] [PG ID 0 PG GUID 0 Rank 5] ProcessGroupNCCL aborts successfully.
[rank5]:[I1219 14:16:20.387954920 ProcessGroupNCCL.cpp:1267] [PG ID 0 PG GUID 0 Rank 5] ProcessGroupNCCL watchdog thread joined.
[rank6]:[I1219 14:16:20.387959319 ProcessGroupNCCL.cpp:1116] [PG ID 0 PG GUID 0 Rank 6] future is successfully executed for: ProcessGroup abort
[rank6]:[I1219 14:16:20.387962662 ProcessGroupNCCL.cpp:1237] [PG ID 0 PG GUID 0 Rank 6] ProcessGroupNCCL aborts successfully.
[rank0]:[I1219 14:16:20.343798546 TCPStoreLibUvBackend.cpp:119] [c10d - debug] Read callback failed. code:-4095 name:EOF desc:end of file
[rank6]:[I1219 14:16:20.387968803 ProcessGroupNCCL.cpp:1267] [PG ID 0 PG GUID 0 Rank 6] ProcessGroupNCCL watchdog thread joined.
[rank5]:[I1219 14:16:20.387983615 ProcessGroupNCCL.cpp:1271] [PG ID 0 PG GUID 0 Rank 5] ProcessGroupNCCL heart beat monitor thread joined.
[rank0]:[I1219 14:16:20.343813505 TCPStoreLibUvBackend.cpp:119] [c10d - debug] Read callback failed. code:-4095 name:EOF desc:end of file
[rank6]:[I1219 14:16:20.387983828 ProcessGroupNCCL.cpp:1271] [PG ID 0 PG GUID 0 Rank 6] ProcessGroupNCCL heart beat monitor thread joined.
--------------------------------------------------------------------------
prterun detected that one or more processes exited with non-zero status,
thus causing the job to be terminated. The first process to do so was:

   Process name: [prterun-farm22-gpu0304-47532@1,6] Exit code:    1
--------------------------------------------------------------------------

Also, here is the NCCL log:

farm22-gpu0303:47733:48605 [2] misc/ipcsocket.cc:221 NCCL WARN UDS: Sending data over socket /tmp/nccl-socket-5-66077aa4ca9712b failed : Connection refused (111)

farm22-gpu0303:47733:48605 [2] NCCL INFO proxy.cc:1111 -> 2

farm22-gpu0303:47733:48605 [2] proxy.cc:1122 NCCL WARN ncclProxyCallBlockingUDS call to tpRank 5(66077aa4ca9712b) failed : 2

farm22-gpu0303:47733:48605 [2] NCCL INFO proxy.cc:1132 -> 2

farm22-gpu0303:47733:48605 [2] proxy.cc:1140 NCCL WARN ncclProxyClientGetFd call to tpRank 5 handle 0x14ee58041500 failed : 2

farm22-gpu0303:47733:48605 [2] NCCL INFO transport/p2p.cc:249 -> 2

farm22-gpu0303:47733:48605 [2] NCCL INFO transport/p2p.cc:330 -> 2

farm22-gpu0303:47733:48605 [2] NCCL INFO transport/p2p.cc:460 -> 2

farm22-gpu0303:47733:48605 [2] NCCL INFO transport.cc:165 -> 2

farm22-gpu0303:47733:48605 [2] NCCL INFO init.cc:1263 -> 2

farm22-gpu0303:47733:48605 [2] NCCL INFO init.cc:1548 -> 2

farm22-gpu0303:47733:48605 [2] NCCL INFO group.cc:64 -> 2 [Async thread]

farm22-gpu0303:47733:48645 [2] proxy.cc:1546 NCCL WARN [Service thread] Error encountered progressing operation=Cofarm22-gpu0303:47734:48606 [3] NCCL INFO Imported shareable buffer device 3 size 10485760 ptr 0x14ddbca00000

farm22-gpu0303:47734:48606 [3] NCCL INFO ProxyCall UDS comm 0x1836cbd0 rank 3 tpRank 6(32145f9c4bc28b82) reqSize 8 respSize 0 respFd 0x14dde030eca8 opId 0xf51804ff97c439cf

farm22-gpu0303:47734:48644 [3] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer farm22-gpu0304.internal.sanger.ac.uk<54066>

farm22-gpu0303:47734:48644 [3] NCCL INFO misc/socket.cc:752 -> 6

farm22-gpu0303:47734:48644 [3] NCCL INFO transport/net_ib.cc:1207 -> 6

farm22-gpu0303:47734:48644 [3] NCCL INFO transport/net.cc:837 -> 6

farm22-gpu0303:47734:48644 [3] NCCL INFO proxyProgressAsync opId=0x14dd8ad7b540 op.type=4 op.reqBuff=0x14dd6005af00 op.respSize=21040 done

farm22-gpu0303:47734:48644 [3] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer farm22-gpu0304.internal.sanger.ac.uk<40172>

farm22-gpu0303:47734:48644 [3] NCCL INFO misc/socket.cc:752 -> 6

farm22-gpu0303:47734:48644 [3] NCCL INFO transport/net_ib.cc:1207 -> 6

farm22-gpu0303:47734:48644 [3] NCCL INFO transport/net.cc:837 -> 6

farm22-gpu0303:47734:48644 [3] NCCL INFO proxyProgressAsync opId=0x14dd8ad63760 op.type=4 op.reqBuff=0x14dd60054120 op.respSize=21040 done

6

farm22-gpu0303:47732:48647 [0] proxy.cc:1556 NCCL WARN [Service thread] Could not receive type from localRank 1, res=6, closed=0

farm22-gpu0303:47732:48647 [0] proxy.cc:1580 NCCL WARN [Proxy Service 4] Failed to execute operation Unknown from rank 5, retcode 6

farm22-gpu0303:47732:48647 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer farm22-gpu0304.internal.sanger.ac.uk<46696>

farm22-gpu0303:47732:48647 [0] NCCL INFO misc/socket.cc:752 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net_ib.cc:1207 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net.cc:837 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO proxyProgressAsync opId=0x15113ad99198 op.type=4 op.reqBuff=0x15111009b550 op.respSize=21040 done

farm22-gpu0303:47732:48647 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer farm22-gpu0304.internal.sanger.ac.uk<38510>

farm22-gpu0303:47732:48647 [0] NCCL INFO misc/socket.cc:752 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net_ib.cc:1207 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net.cc:837 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO proxyProgressAsync opId=0x15113ad44c28 op.type=4 op.reqBuff=0x151110054110 op.respSize=21040 done

farm22-gpu0303:47732:48647 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer farm22-gpu0304.internal.sanger.ac.uk<35398>

farm22-gpu0303:47732:48647 [0] NCCL INFO misc/socket.cc:752 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net_ib.cc:1207 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net.cc:837 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO proxyProgressAsync opId=0x15113ad697b8 op.type=4 op.reqBuff=0x15111005af80 op.respSize=21040 done

farm22-gpu0303:47732:48647 [0] misc/socket.cc:50 NCCL WARN socketProgress: Connection closed by remote peer farm22-gpu0304.internal.sanger.ac.uk<43378>

farm22-gpu0303:47732:48647 [0] NCCL INFO misc/socket.cc:752 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net_ib.cc:1207 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO transport/net.cc:837 -> 6

farm22-gpu0303:47732:48647 [0] NCCL INFO proxyProgressAsync opId=0x15113ad81598 op.type=4 op.reqBuff=0x15111005afa0 op.respSize=21040 done

Thank you, increasing the ulimit worked.
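
For anyone landing here later: the thread does not say which limit was raised or how, so the following is only a sketch under the assumption that it was the locked-memory limit (`RLIMIT_MEMLOCK`) from the linked post. It raises the soft limit to the hard limit from inside the process, which an unprivileged process is allowed to do; `raise_soft_limit` is a hypothetical helper name.

```python
import resource


def raise_soft_limit(which):
    """Raise the soft limit for `which` to its hard limit; return the new pair."""
    soft, hard = resource.getrlimit(which)
    if soft != hard:
        # An unprivileged process may raise its soft limit up to the hard limit.
        resource.setrlimit(which, (hard, hard))
    return resource.getrlimit(which)


if __name__ == "__main__":
    # Assumption: the locked-memory limit is the one that matters here.
    print(raise_soft_limit(resource.RLIMIT_MEMLOCK))
```

This would need to run before `dist.init_process_group` is called. If the hard limit itself is too low, it cannot be raised this way, and that part has to be changed by the system administrators.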
