Multi-node error on process group destruction: CUDA error: invalid device ordinal

Hi,
I’m running distributed code on a multi-node setup using torch.distributed with NCCL backend and multiple process groups.
Everything works fine until process group destruction, where I get the error Exception thrown when waiting for future ProcessGroup abort: CUDA error: invalid device ordinal
on every rank whose global rank differs from its local_rank (i.e. all ranks on every node except the first), even though my CUDA device is correctly set to local_rank.
Here is a minimal script to reproduce the bug:

import os

import torch
import torch.distributed as dist

local_rank = int(os.getenv("LOCAL_RANK"))
rank = int(os.getenv("RANK"))
dist.init_process_group(backend="nccl")
torch.cuda.set_device(local_rank)
ws = dist.get_world_size()

# Second process group containing all ranks.
members = list(range(ws))
group = dist.new_group(ranks=members)

# Pass a tensor down the chain: each rank receives from its predecessor,
# then forwards to its successor.
x = torch.full((2, 3), rank, device=local_rank)
if rank != 0:
    dist.recv(x, src=rank - 1, group=group)
if rank != ws - 1:
    dist.send(x, dst=rank + 1, group=group)

torch.cuda.synchronize()
dist.barrier(group=group)

if rank in members:
    dist.destroy_process_group(group)

dist.destroy_process_group()

started on every node with:

torchrun --nnodes 2 --nproc-per-node 4 --rdzv-id 40184 --rdzv-backend c10d --rdzv-endpoint x1002c0s3b0n0 script.py

Note that the error only occurs if a second process group is created and P2P operations happen in that group. Also, both calls to dist.destroy_process_group() are optional; the same error happens without them.
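One possibly related knob (an untested guess on my side, not a confirmed fix): recent torch versions (2.3+) accept a device_id argument in init_process_group that binds the process group to a specific local device, which might keep the shutdown/abort path from inferring a wrong device ordinal. A minimal sketch:

# Untested idea: bind the default process group to this rank's local GPU
# so that shutdown/abort code never has to guess the device ordinal.
# device_id was added to init_process_group in torch 2.3.
import os
import torch
import torch.distributed as dist

local_rank = int(os.getenv("LOCAL_RANK"))
torch.cuda.set_device(local_rank)
dist.init_process_group(
    backend="nccl",
    device_id=torch.device(f"cuda:{local_rank}"),
)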

Additional info:

Python 3.11.5
torch 2.4.0+cu124.post2
CUDA 12.4.0
NCCL 2.20.5

Reproduced on two different clusters (V100 and H100), with both torch 2.4 and 2.5.

Full error (same for ranks 5, 6, and 7):

[rank4]: Traceback (most recent call last):
[rank4]:   File "/net/home/project/tutorial/tutorial051/toy/nccl-test.py", line 99, in <module>
[rank4]:     run_multinode_test(local_rank, rank, ws)
[rank4]:   File "/net/home/project/tutorial/tutorial051/toy/nccl-test.py", line 89, in run_multinode_test
[rank4]:     dist.destroy_process_group(group)
[rank4]:   File "/net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 1763, in destroy_process_group
[rank4]:     _shutdown_backend(pg)
[rank4]:   File "/net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 1458, in _shutdown_backend
[rank4]:     backend._shutdown()
[rank4]: torch.distributed.DistBackendError: [PG 1 Rank 4] Exception thrown when waitng for future ProcessGroup abort: CUDA error: invalid device ordinal
[rank4]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

[rank4]: Exception raised from c10_cuda_check_implementation at /net/home/local/softmgr/tmp/building/pytorchd990dad/c10/cuda/CUDAException.cpp:43 (most recent call first):
[rank4]: frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0xc8 (0x400029a65508 in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libc10.so)
[rank4]: frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xc8 (0x400029a14d1c in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libc10.so)
[rank4]: frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x2cc (0x400029983a4c in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
[rank4]: frame #3: c10::cuda::ExchangeDevice(signed char) + 0x5c (0x400029983d90 in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
[rank4]: frame #4: c10d::ProcessGroupNCCL::abortCommsFromMap(std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::shared_ptr<c10d::NCCLComm>, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::shared_ptr<c10d::NCCLComm> > > >&, std::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) + 0x170 (0x40001ea5b360 in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
[rank4]: frame #5: c10d::ProcessGroupNCCL::abort(std::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >) + 0xb4 (0x40001ea5b764 in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
[rank4]: frame #6: <unknown function> + 0xf7bebc (0x40001ea5bebc in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
[rank4]: frame #7: <unknown function> + 0xe25ffc (0x40001e905ffc in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
[rank4]: frame #8: <unknown function> + 0x11a9c (0x400013a51a9c in /lib64/libpthread.so.0)
[rank4]: frame #9: <unknown function> + 0xf723ec (0x40001ea523ec in /net/scratch/hscra/project/tutorial/tutorial051/venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
[rank4]: frame #10: <unknown function> + 0xdd52c (0x40001534d52c in /net/software/aarch64/el8/GCCcore/13.2.0/lib64/libstdc++.so.6)
[rank4]: frame #11: <unknown function> + 0x875c (0x400013a4875c in /lib64/libpthread.so.0)
[rank4]: frame #12: <unknown function> + 0xdfeec (0x400013b8feec in /lib64/libc.so.6)

Does anyone know what is going wrong here?


I am seeing the same issue on AMD platforms when using DeviceMesh. However, using P2POp with batch_isend_irecv does not appear to cause this error during shutdown.
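For reference, a sketch of the send/recv chain from the original script rewritten with the batched API (assuming the same x, rank, ws, and group variables as above; note the batched ops are issued concurrently, so middle ranks forward their pre-receive contents rather than the tensor they just received, unlike the sequential chain):

# Batched P2P variant of the chain from the repro script.
ops = []
if rank != 0:
    # Receive from the previous rank in the chain.
    ops.append(dist.P2POp(dist.irecv, x, peer=rank - 1, group=group))
if rank != ws - 1:
    # Forward to the next rank in the chain.
    ops.append(dist.P2POp(dist.isend, x, peer=rank + 1, group=group))
if ops:
    # batch_isend_irecv returns one request per op; wait on all of them.
    for req in dist.batch_isend_irecv(ops):
        req.wait()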
