Single Machine DDP Issue on A6000 GPU

Hi,
I’ve recently gotten access to some A6000 GPUs. The machine has CUDA 11.3 installed, and my environment has the latest PyTorch release (1.10.0) with the CUDA 11.3 build: torch==1.10.0+cu113.

It seems like single-GPU training works well, but as soon as I switch to DDP (initiated when torch.cuda.device_count() > 1), training does not run.
Interestingly, GPU memory gets occupied by the expected amount and GPU utilization sits at 100%, but the code makes no progress beyond that point. There are no error messages - it just hangs.

Is this a known bug for A6000 GPUs? I’ve tested my DDP code on a wide range of GPUs (TitanXP, RTX 2080Ti, Quadro RTX 8000, etc.) and haven’t faced issues. The code doesn’t even fail (it just hangs), so I’m not sure how to start debugging this. Has anyone else faced similar issues?

You could use:

export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

to check if you are running into e.g. a setup issue and NCCL is failing.
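
If it’s easier, the same can be done from the launching Python script before any worker process is created (just a sketch; the variables only need to be in the environment before NCCL initializes):

import os

# Same effect as the exports above: make NCCL print verbose diagnostics.
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"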

Hi @ptrblck, thanks for the suggestion.
I’ve filtered the NCCL INFO output down as much as possible (and removed the server name / IP):

MYSERVER:84140:84140 [0] NCCL INFO Bootstrap : Using ens8f0:000.00.000.000<0>
MYSERVER:84140:84140 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
MYSERVER:84140:84140 [0] NCCL INFO NET/IB : No device found.
MYSERVER:84140:84140 [0] NCCL INFO NET/Socket : Using [0]ens8f0:000.00.000.000<0>
MYSERVER:84140:84140 [0] NCCL INFO Using network Socket
NCCL version 2.10.3+cuda11.3
MYSERVER:84140:84275 [0] NCCL INFO bootstrap.cc:107 Mem Alloc Size 28 pointer 0x7f8fa8000b20
MYSERVER:84140:84276 [0] NCCL INFO init.cc:260 Mem Alloc Size 18872 pointer 0x7f8fa0002f70
MYSERVER:84140:84276 [0] NCCL INFO misc/utils.cc:30 Mem Alloc Size 12 pointer 0x7f8fa000d390
MYSERVER:84141:84141 [1] NCCL INFO Bootstrap : Using ens8f0:000.00.000.000<0>
MYSERVER:84141:84141 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation
MYSERVER:84141:84277 [1] NCCL INFO === System : maxWidth 24.0 totalWidth 24.0 ===
MYSERVER:84141:84277 [1] NCCL INFO CPU/1 (1/2/-1)
MYSERVER:84141:84277 [1] NCCL INFO + PCI[24.0] - GPU/A1000 (0)
MYSERVER:84141:84277 [1] NCCL INFO + PCI[24.0] - GPU/C1000 (1)
MYSERVER:84141:84277 [1] NCCL INFO + PCI[3.0] - NIC/C2000
MYSERVER:84141:84277 [1] NCCL INFO ==========================================
MYSERVER:84141:84277 [1] NCCL INFO GPU/A1000 :GPU/A1000 (0/5000.000000/LOC) GPU/C1000 (2/24.000000/PHB) CPU/1 (1/24.000000/PHB)
MYSERVER:84141:84277 [1] NCCL INFO GPU/C1000 :GPU/A1000 (2/24.000000/PHB) GPU/C1000 (0/5000.000000/LOC) CPU/1 (1/24.000000/PHB)
MYSERVER:84141:84277 [1] NCCL INFO Pattern 4, crossNic 0, nChannels 2, speed 12.000000/12.000000, type PHB/PIX, sameChannels 1
MYSERVER:84141:84277 [1] NCCL INFO  0 : GPU/0 GPU/1
MYSERVER:84141:84277 [1] NCCL INFO  1 : GPU/0 GPU/1
MYSERVER:84141:84277 [1] NCCL INFO Pattern 1, crossNic 0, nChannels 2, speed 22.000000/22.000000, type PHB/PIX, sameChannels 0
MYSERVER:84141:84277 [1] NCCL INFO  0 : GPU/0 GPU/1
MYSERVER:84141:84277 [1] NCCL INFO  1 : GPU/1 GPU/0
MYSERVER:84141:84277 [1] NCCL INFO Pattern 3, crossNic 0, nChannels 2, speed 22.000000/22.000000, type PHB/PIX, sameChannels 0
MYSERVER:84141:84277 [1] NCCL INFO  0 : GPU/0 GPU/1
MYSERVER:84141:84277 [1] NCCL INFO  1 : GPU/1 GPU/0

The rest seems to be info about memory allocation, something with β€œtrees”, β€œgraphs” and β€œrings”.

There is no other mention of NCCL failing. The last line printed is:
MYSERVER:84140:84140 [0] NCCL INFO group.cc:306 Mem Alloc Size 8 pointer 0x6814150
and the code hangs after that.

Sorry for the hassle!

Thanks for the update. Unfortunately, this doesn’t point to any issues, so could you post a minimal, executable code snippet that reproduces the hang, please?

This is an extremely minimal script that won’t run on the A6000 with DDP:

import os
import socket

import torch
import torch.multiprocessing as mp
import torch.nn as nn
from torch import distributed as dist
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm


def find_free_port():
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    return port


def main():
    num_available_gpus = torch.cuda.device_count()
    assert num_available_gpus > 1, "Use more than 1 GPU for DDP"
    world_size = num_available_gpus
    os.environ['MASTER_ADDR'] = "localhost"
    os.environ['MASTER_PORT'] = str(find_free_port())

    print("Distributed Test Code")
    print(f"GPU name: {torch.cuda.get_device_name()}")
    print(f"CUDA capability: {torch.cuda.get_device_capability()}")
    print(f"Num GPUs: {num_available_gpus}")

    mp.spawn(main_process, nprocs=num_available_gpus, args=(num_available_gpus,))


def main_process(gpu, world_size):
    dist.init_process_group(backend="nccl", init_method="env://", world_size=world_size, rank=gpu)
    torch.cuda.set_device(gpu)
    model = nn.Linear(10, 10).cuda(gpu)
    model = DistributedDataParallel(model, device_ids=[gpu])

    criterion = nn.L1Loss()

    for epoch in range(10):
        for iteration in tqdm(range(1000), desc=f"Ep: {epoch}"):
            for p in model.parameters():
                p.grad = None
            x = torch.randn(2, 10).cuda(gpu)
            gt = torch.randn(2, 10).cuda(gpu)
            out = model(x)
            loss = criterion(out, gt)
            loss.backward()


if __name__ == '__main__':
    main()

Testing this on my RTX 2080Ti server, I get the following output:

Distributed Test Code
GPU name: GeForce RTX 2080 Ti
CUDA capability: (7, 5)
Num GPUs: 2
Ep: 0: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1000/1000 [00:02<00:00, 434.17it/s
Ep: 1: ...
...

On the A6000 server, I get the following:

Distributed Test Code
GPU name: NVIDIA RTX A6000
CUDA capability: (8, 6)
Num GPUs: 2
Ep: 0:   0%|                                                                                   | 1/1000 [00:00<13:14,  1.26it/s]

And the code will not progress any further.

It seems like the issue affects not just this machine, but other machines in the cluster as well.

This sounds rather like a setup issue, as I’m also able to run your code without any hang.

EDIT: try attaching to the hanging process via gdb and printing the backtrace, which should point to the operation that hangs and could help isolate the issue.
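
(As a Python-level complement, and only a rough sketch: the standard faulthandler module can dump where each rank’s Python threads are blocked, if gdb alone isn’t conclusive. Something like this could be added at the top of each worker process, assuming a Unix system:)

import faulthandler
import signal

# Sketch: dump the Python stacks of all threads in this process when it
# receives SIGUSR1 (send it from another shell with `kill -USR1 <pid>`).
# This shows which Python call each rank is blocked in; the gdb backtrace
# is still needed for the C/CUDA side.
faulthandler.register(signal.SIGUSR1, all_threads=True)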

When I interrupt the hanging process, I get a timeout error at:

fd_event_list = self._selector.poll(timeout)

It seems like the processes can’t communicate with each other. So I tried wrapping the training loop in model.no_sync(), and with that the code progresses fine.
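
Roughly, what I changed looks like this (a sketch of the training loop from main_process in the script above; note that no_sync() skips the gradient all-reduce entirely, so this is only a diagnostic, not a real fix):

import torch
import torch.nn as nn


def train_without_sync(model, gpu, num_iterations=1000):
    # model is the DistributedDataParallel module built in main_process.
    # Everything inside no_sync() skips the gradient all-reduce, so the
    # ranks never communicate (and never hang), but their parameters are
    # no longer kept in sync.
    criterion = nn.L1Loss()
    with model.no_sync():
        for iteration in range(num_iterations):
            for p in model.parameters():
                p.grad = None
            x = torch.randn(2, 10).cuda(gpu)
            gt = torch.randn(2, 10).cuda(gpu)
            loss = criterion(model(x), gt)
            loss.backward()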

As you mentioned, this definitely does look like a setup issue.

Any ideas what might be causing communication issues locally?
By the way, thanks for your help :smiley:

Edit: I have found that using the gloo backend is a temporary fix for this issue. However, there is no noticeable speedup going from a single GPU to 2 GPUs with gloo, whereas usually the speedup is quite noticeable.
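
Concretely, the only change for the gloo workaround is the backend argument in main_process from the script above (a sketch):

from torch import distributed as dist

# Workaround sketch: use the gloo backend instead of nccl. gloo avoids
# NCCL's P2P path entirely, but gave me no speedup over a single GPU for
# this toy model.
dist.init_process_group(backend="gloo", init_method="env://",
                        world_size=world_size, rank=gpu)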

Edit 2: Disabling P2P with export NCCL_P2P_DISABLE=1 also seems to solve the issue, making NCCL fall back to shared memory instead of direct P2P.
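
For reference, the variable can also be set inside the launcher before the workers are spawned (a sketch of the change to main() in the script above; the spawned workers inherit the environment and NCCL reads it when the communicator is created):

import os

import torch.multiprocessing as mp

# Equivalent to `export NCCL_P2P_DISABLE=1`: NCCL then uses shared memory
# instead of P2P transfers between the two GPUs.
os.environ["NCCL_P2P_DISABLE"] = "1"
mp.spawn(main_process, nprocs=num_available_gpus, args=(num_available_gpus,))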

It didn’t work :sweat_smile: Did you try another method to fix it?

Have you tried export NCCL_P2P_DISABLE=1?

Yes, I added it to my bash file and also entered it in the terminal. It didn’t work:

Distributed Test Code
GPU name: NVIDIA RTX A6000
CUDA capability: (8, 6)
Num GPUs: 8
Traceback (most recent call last):
  File "test/multi_gpu_test.py", line 56, in <module>
    main()
  File "test/multi_gpu_test.py", line 33, in main
    mp.spawn(main_process, nprocs=num_available_gpus, args=(num_available_gpus,))
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 240, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 198, in start_processes
    while not context.join():
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 160, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
torch.multiprocessing.spawn.ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/dong.sun/test/multi_gpu_test.py", line 39, in main_process
    model = nn.Linear(10, 10).cuda(gpu)
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 689, in cuda
    return self._apply(lambda t: t.cuda(device))
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 602, in _apply
    param_applied = fn(param)
  File "/home/dong.sun/miniconda3/envs/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py", line 689, in <lambda>
    return self._apply(lambda t: t.cuda(device))
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

It seems like you’re running into an OOM issue, which is not the topic of this post.

I also had the same problem on the A6000; NCCL_P2P_DISABLE=1 helped me solve it.

Could you please share whether you have figured out the root cause?

I found this on a forum a few years ago.
Check whether ACS is enabled with sudo lspci -vvv | grep ACSCtl. You should see values such as SrcValid-; if any of them show + instead of -, then apparently ACS is enabled on that device.
Then you can disable ACS via the following script:

#!/bin/bash
#
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
#

# must be root to access extended PCI config space
if [ "$EUID" -ne 0 ]; then
  echo "ERROR: $0 must be run as root"
  exit 1
fi

for BDF in `lspci -d "*:*:*" | awk '{print $1}'`; do

    # skip if it doesn't support ACS
    setpci -v -s ${BDF} ECAP_ACS+0x6.w > /dev/null 2>&1
    if [ $? -ne 0 ]; then
        #echo "${BDF} does not support ACS, skipping"
        continue
    fi

    logger "Disabling ACS on `lspci -s ${BDF}`"
    setpci -v -s ${BDF} ECAP_ACS+0x6.w=0000
    if [ $? -ne 0 ]; then
        logger "Error disabling ACS on ${BDF}"
        continue
    fi
    NEW_VAL=`setpci -v -s ${BDF} ECAP_ACS+0x6.w | awk '{print $NF}'`
    if [ "${NEW_VAL}" != "0000" ]; then
        logger "Failed to disable ACS on ${BDF}"
        continue
    fi
done
exit 0

Installing the NVLink bridge fixed the problem for me.

For anyone in an environment where NVLink cannot be installed, the workarounds above may help.