MPI backend and GPU tensor error

Hi,
I cannot use CUDA tensors with the MPI backend.

I compiled OpenMPI with:

./configure --prefix=/usr/local --with-cuda=${CUDA_HOME} --with-hwloc=internal --with-libevent=internal --with-ucc=/usr/local --with-ucx=/usr/local --enable-mpi-thread-multiple --enable-mpi-cxx

and then built PyTorch (torch-2.1.0a0+git0af3203) with:

export CUDA_NVCC_EXECUTABLE="/usr/local/cuda-12/bin/nvcc"
export CUDA_HOME="/usr/local/cuda-12"
export CUDNN_INCLUDE_PATH="/usr/local/cuda-12/include/"
export CUDNN_LIBRARY_PATH="/usr/local/cuda-12/lib64/"
export LIBRARY_PATH="/usr/local/cuda-12/lib64"
export USE_CUDA=1
export USE_CUDNN=1
export USE_MKLDNN=1
export ATEN_NO_TEST=1
export BUILD_ONNX_PYTHON=1
export USE_TENSORRT=0
export USE_FFMPEG=1
export USE_DISTRIBUTED=1
export USE_GLOO=1
export USE_MPI=1
export USE_TENSORPIPE=1
export USE_OPENCV=1
export USE_UCC=1
export USE_SYSTEM_UCC=1
export USE_NUMPY=1
export USE_OPENCL=0
export USE_ZMQ=1
export USE_FFTW=1
export USE_MPS=1
export BUILD_TEST=0
export CMAKE_CUDA_COMPILER="/usr/local/cuda-12/bin/nvcc"
python setup.py bdist_wheel
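
Before testing, it can help to confirm the wheel actually picked up CUDA, distributed support, and the MPI backend. A minimal check using the public torch.distributed API:

import torch
import torch.distributed as dist

print(torch.version.cuda)         # CUDA version PyTorch was built against
print(torch.cuda.is_available())  # True if a GPU is visible at runtime
print(dist.is_available())        # True if USE_DISTRIBUTED took effect
print(dist.is_mpi_available())    # True if the MPI backend was compiled in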

After that, I tried the example from https://github.com/Stonesjtu/pytorch-learning/blob/master/build-with-mpi.md, which is:

import torch
import torch.distributed as dist

dist.init_process_group(backend='mpi')
# Fill a tensor with this process's rank and move it to the GPU
t = torch.zeros(5, 5).fill_(dist.get_rank()).cuda()
dist.all_reduce(t)

and I got the error: RuntimeError: No backend type associated with device type cuda
If I remove the .cuda(), it works (see the CPU-only variant below). Nevertheless, I would like to use MPI with my GPUs.
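
For reference, this CPU-only variant runs fine:

import torch
import torch.distributed as dist

dist.init_process_group(backend='mpi')
# Same reduction, but on a CPU tensor -- this completes without error
t = torch.zeros(5, 5).fill_(dist.get_rank())
dist.all_reduce(t)
print(t[0, 0].item())  # sum of all ranks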
I checked the linked libraries with ldd (ldd /usr/local/lib/python3.10/site-packages/torch/*.so | grep -E "libopen-rte|libmpi"):

   libmpi_cxx.so.40 => /lib/x86_64-linux-gnu/libmpi_cxx.so.40 (0x00007ff6f0d6e000)
   libmpi.so.40 => /usr/local/lib/libmpi.so.40 (0x00007ff6f0c2f000)
   libopen-rte.so.40 => /usr/local/lib/libopen-rte.so.40 (0x00007ff6dc5d6000)

I also checked that MPI is CUDA-aware with https://gist.githubusercontent.com/K-Wu/6c353273aafe9a4eaaa344a8b74475b6/raw/bd7ea29cdd72801442ad6d69a88be5fee67cf534/mpi_cuda_awareness_check.cpp and with https://github.com/NVIDIA/nccl-tests, and both seem to work.

Can you help me, please?

I'm not sure we generally test/support this use case. I'm not saying it should not work, but I wonder if you have a reason to prefer MPI+CUDA over NCCL+CUDA? NCCL is the communication library commonly used with CUDA in PyTorch.
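
For comparison, the NCCL version of your snippet would look something like this (a sketch assuming the usual env:// rendezvous variables RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are set, e.g. by torchrun):

import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')
# NCCL needs each process pinned to its own GPU
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
t = torch.zeros(5, 5).fill_(dist.get_rank()).cuda()
dist.all_reduce(t)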

OK, but MPI can also use NCCL underneath. It is a simple and easy way to write one code path that distributes a task across both CPUs and GPUs.
Moreover, it is better when distributing work on CPUs and GPUs at the same time.

Hi, given the error above, it may be worth trying:

dist.init_process_group(backend="cuda:mpi")

to force a mapping from CUDA tensors to the MPI backend.

Thanks!

For those with the same issue: you also have to define RANK and WORLD_SIZE for the rendezvous.
With OpenMPI, that is something like:

import os
import torch.distributed

# Map OpenMPI's rank/size variables to the ones torch.distributed expects
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
# Register MPI for both CPU and CUDA tensors
torch.distributed.init_process_group(backend='cpu:mpi,cuda:mpi')
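
Putting it all together, a complete sketch (the script name allreduce_mpi.py is just an example; depending on your PyTorch version you may also need to set MASTER_ADDR and MASTER_PORT):

import os
import torch
import torch.distributed as dist

# Launch with OpenMPI, e.g.: mpirun -np 2 python allreduce_mpi.py
os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']

# Register MPI for both CPU and CUDA tensors
dist.init_process_group(backend='cpu:mpi,cuda:mpi')

# Give each rank its own GPU when several are visible
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

t = torch.zeros(5, 5).fill_(dist.get_rank()).cuda()
dist.all_reduce(t)
print(f"rank {dist.get_rank()}: {t[0, 0].item()}")  # sum of all ranks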