Segfault using CUDA with OpenMPI

Hi there!
I am following the tutorial for writing distributed applications.
Here is the basic example code. Everything works well as long as I don't try using CUDA.
Is there a specific step to enable it that I missed? I am getting a segfault…
It might also be linked to some kind of permission issue, since the signal code is "Invalid permissions".
Any help would be really appreciated!

import os
import torch
import torch.distributed as dist
import platform


def run(rank, size):
    tensor = torch.zeros(1).cuda(0)
    if rank == 0:
        # Send the tensor to process 1
        tensor += 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])


def init_processes(fn):
    """ Initialize the distributed environment. """
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    size = dist.get_world_size()
    print('I am rank ', rank, ' on ', platform.node())
    fn(rank, size)


if __name__ == "__main__":
    init_processes(run)

And here is the command output:

$ mpiexec -np 2 python main.py

WARNING: Linux kernel CMA support was requested via the
btl_vader_single_copy_mechanism MCA variable, but CMA support is
not available due to restrictive ptrace settings.

The vader shared memory BTL will fall back on another single-copy
mechanism if one is available. This may result in lower performance.

  Local host: iccluster131
--------------------------------------------------------------------------
I am rank  0  on  iccluster131
I am rank  1  on  iccluster131
[iccluster131:19580] *** Process received signal ***
[iccluster131:19580] Signal: Segmentation fault (11)
[iccluster131:19580] Signal code: Invalid permissions (2)
[iccluster131:19580] Failing at address: 0x420a900000
[iccluster131:19580] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7fda4b7ae390]
[iccluster131:19580] [ 1] /lib/x86_64-linux-gnu/libc.so.6(+0x14db15)[0x7fda4ac08b15]
[iccluster131:19580] [ 2] /home/me/.conda/envs/pytorch-env/lib/./libopen-pal.so.40(opal_convertor_pack+0x175)[0x7fda1ae5e925]
[iccluster131:19580] [ 3] /home/me/.conda/envs/pytorch-env/lib/openmpi/mca_btl_vader.so(mca_btl_vader_sendi+0x383)[0x7fda10799803]
[iccluster131:19580] [ 4] /home/me/.conda/envs/pytorch-env/lib/openmpi/mca_pml_ob1.so(+0xb6db)[0x7fda0bde56db]
[iccluster131:19580] [ 5] /home/me/.conda/envs/pytorch-env/lib/openmpi/mca_pml_ob1.so(mca_pml_ob1_send+0x690)[0x7fda0bde7120]
[iccluster131:19580] [ 6] /home/me/.conda/envs/pytorch-env/lib/libmpi.so.40(PMPI_Send+0xf2)[0x7fda254bc882]
[iccluster131:19580] [ 7] /home/me/.conda/envs/pytorch-env/lib/python3.6/site-packages/torch/_C.cpython-36m-x86_64-linux-gnu.so(_Z15THDPModule_sendP7_objectS0_+0xd5)[0x7fda42831e95]
[iccluster131:19580] [ 8] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(_PyCFunction_FastCallDict+0x18e)[0x7fda4ba8660e]
[iccluster131:19580] [ 9] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x16669a)[0x7fda4bb2069a]
[iccluster131:19580] [10] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(_PyEval_EvalFrameDefault+0x4186)[0x7fda4bb25046]
[iccluster131:19580] [11] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x16629e)[0x7fda4bb2029e]
[iccluster131:19580] [12] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x1665b2)[0x7fda4bb205b2]
[iccluster131:19580] [13] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(_PyEval_EvalFrameDefault+0x3bd8)[0x7fda4bb24a98]
[iccluster131:19580] [14] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x165930)[0x7fda4bb1f930]
[iccluster131:19580] [15] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x166854)[0x7fda4bb20854]
[iccluster131:19580] [16] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(_PyEval_EvalFrameDefault+0x4186)[0x7fda4bb25046]
[iccluster131:19580] [17] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x165930)[0x7fda4bb1f930]
[iccluster131:19580] [18] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x166854)[0x7fda4bb20854]
[iccluster131:19580] [19] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(_PyEval_EvalFrameDefault+0x4186)[0x7fda4bb25046]
[iccluster131:19580] [20] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(+0x16629e)[0x7fda4bb2029e]
[iccluster131:19580] [21] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(PyEval_EvalCodeEx+0x6d)[0x7fda4bb208cd]
[iccluster131:19580] [22] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(PyEval_EvalCode+0x3b)[0x7fda4bb2091b]
[iccluster131:19580] [23] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(PyRun_FileExFlags+0xb2)[0x7fda4bb5b472]
[iccluster131:19580] [24] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(PyRun_SimpleFileExFlags+0xe7)[0x7fda4bb5b5d7]
[iccluster131:19580] [25] /home/me/.conda/envs/pytorch-env/bin/../lib/libpython3.6m.so.1.0(Py_Main+0xf2c)[0x7fda4bb766dc]
[iccluster131:19580] [26] python(main+0x16e)[0x400bce]
[iccluster131:19580] [27] /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf0)[0x7fda4aadb830]
[iccluster131:19580] [28] python[0x400c95]
[iccluster131:19580] *** End of error message ***
-------------------------------------------------------
Primary job  terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.
-------------------------------------------------------
--------------------------------------------------------------------------
mpiexec noticed that process rank 0 with PID 0 on node iccluster131 exited on signal 11 (Segmentation fault).
--------------------------------------------------------------------------
[iccluster131:19575] 1 more process has sent help message help-btl-vader.txt / cma-permission-denied
[iccluster131:19575] Set MCA parameter "orte_base_help_aggregate" to 0 to see all help / error messages

My MPI version:

mpiexec --version
mpiexec (OpenRTE) 3.0.0

This is the version PyTorch was compiled against.

python --version
Python 3.6.3 :: Anaconda, Inc.

python -c 'import torch;print(torch.__version__)'
0.4.0a0+0ab68b8

OK, solved it.
Since I was a bit lost due to the lack of documentation on using PyTorch + MPI + GPU, I will give here the steps I followed.
The main thing I was missing is that I needed an OpenMPI build that is "CUDA-aware".

Main steps to follow:

  1. Install a "CUDA-aware" OpenMPI: you need to compile it from source
  2. Install PyTorch from source

Step by step:

1. Install OpenMPI --with-cuda

If you already have OpenMPI and want to check whether it is "CUDA-aware", run:

ompi_info --parsable --all | grep mpi_built_with_cuda_support:value

If you get true: perfect, nothing to do.
If you get false: too bad, you need to recompile it.
(source: Link)

At this step, if you have OpenMPI (or any other MPI implementation) already installed, I strongly advise uninstalling it, so as not to mess things up with the paths…

Then, download the latest OpenMPI version here and extract it:

wget https://www.open-mpi.org/software/ompi/v3.0/downloads/openmpi-3.0.0.tar.gz
gunzip -c openmpi-3.0.0.tar.gz | tar xf -

We then follow the steps from here, but add the --with-cuda flag to the ./configure command, as explained here:

cd openmpi-3.0.0
./configure --prefix=/home/$USER/.openmpi --with-cuda
make all install

The --prefix parameter sets the install path and is mandatory. Note that you need to choose a directory where you have write permissions (I didn't have them in /usr/local, the one suggested in the link).

Now you can make yourself a cup of coffee (as suggested here), since this takes around 15 minutes.

Once that is done, you need to add the OpenMPI bin directory to your PATH, and its lib directory to your library path:

 export PATH="$PATH:/home/$USER/.openmpi/bin"
 export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/home/$USER/.openmpi/lib/"

I recommend adding these to your .bashrc/.zshrc/… straight away.

Check that it's working; this is more or less what you should get:

mpirun --version
> mpirun (Open MPI) 3.0.0

ompi_info --parsable --all | grep mpi_built_with_cuda_support:value
> mca:mpi:base:param:mpi_built_with_cuda_support:value:true

2. Install pytorch (from source)

Now, you just need to install PyTorch from source (after removing any existing installation properly and entirely).
Be sure to run first:
conda update conda
Then, copying from the GitHub README for convenience:

export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]

# Install basic dependencies
conda install numpy pyyaml mkl setuptools cmake cffi

# Add LAPACK support for the GPU
conda install -c pytorch magma-cuda80 # or magma-cuda75 if CUDA 7.5

Note that the PyTorch README should probably be updated: you need to use the pytorch channel in Anaconda to get the latest version of magma (at least, that was the case for me).

And finally:

git clone --recursive https://github.com/pytorch/pytorch
cd pytorch && python setup.py install

During compilation you should see something like the following, which most probably means that PyTorch found our OpenMPI installation:

Found openMP
Compiling with openMP support

I guess PyTorch automatically detects our installation because it is on our PATH, so be sure to set the PATH before compiling. Also make sure you don't have two MPI installations, and that the one you run with mpirun or mpiexec is the one you just installed.
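As an extra sanity check, you can ask torch itself whether MPI support was compiled in. (I am assuming dist.is_mpi_available() exists in your PyTorch version; it may be missing from older builds, so treat this as an assumption.)

import torch.distributed as dist

# Prints True only if this torch build was compiled against an MPI implementation
print(dist.is_mpi_available())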

Hope this helps others… (if it was obvious to some, it was not for me :stuck_out_tongue:)


Note about using CUDA + OpenMPI with PyTorch:
I had to manually set the device with torch.cuda.set_device(x). If I just use .cuda(x), it crashes with the same cryptic error message.
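A minimal sketch of the pattern that worked for me (assuming one GPU per rank; mapping the rank number to the GPU index is my convention here, nothing imposed by PyTorch):

import torch
import torch.distributed as dist

dist.init_process_group('mpi')
rank = dist.get_rank()

# Select this process's default GPU first;
# calling .cuda(rank) directly crashed with the same segfault.
torch.cuda.set_device(rank)
tensor = torch.zeros(1).cuda()  # now allocated on GPU `rank`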


I followed the steps to compile MPI, but I get this error:

RuntimeError: the MPI backend is not available; try to recompile the THD package with MPI support at /opt/conda/conda-bld/pytorch_1513368888240/work/torch/lib/THD/process_group/General.cpp:17

when I run the example from your first question.

It looks like PyTorch was not compiled with MPI… did you recompile PyTorch as well?
Did you check that the PyTorch you're using is the one you recompiled?
Make sure PyTorch can find MPI during installation by adding the bin folder to your PATH.
If you didn't see:

Found openMP
Compiling with openMP support

during install, then PyTorch probably didn't find MPI.
Good luck!

I cannot import torch after recompiling. I just followed the guide; I will try again.

Hi theevann,

I have some question.

Thanks to you, I successfully installed OpenMPI with CUDA and tested it with "ompi_info --parsable --all | grep mpi_built_with_cuda_support:value". (I installed it in ~/lib/openmpi_cuda.)
I also set PATH and LD_LIBRARY_PATH so that it is the first executable and library found.
I checked as follows:
which mpirun
~/lib/openmpi_cuda/bin/mpirun

But when I install PyTorch from source (https://github.com/pytorch/pytorch) with OpenMPI, it doesn't use that OpenMPI (~/lib/openmpi_cuda).

Instead, it picks up the OpenMPI installed by "conda install -c conda-forge openmpi", which is NOT CUDA-aware.

So how can I make PyTorch use the CUDA-aware OpenMPI I installed when building it (with MPI) from source?

Thanks.

Hi hyu,

I didn't try installing PyTorch with 2 versions of OpenMPI around.
Maybe PyTorch is taking the first version it finds in the library path… did you try putting the path of your CUDA-aware OpenMPI at the beginning of both PATH and LD_LIBRARY_PATH (e.g. export PATH="$HOME/lib/openmpi_cuda/bin:$PATH")?

Yes. I have already set PATH & LD_LIBRARY_PATH

"Thanks to you, I successfully installed OpenMPI with CUDA and tested it with "ompi_info --parsable --all | grep mpi_built_with_cuda_support:value". (I installed it in ~/lib/openmpi_cuda.)
I also set PATH and LD_LIBRARY_PATH so that it is the first executable and library found.
I checked as follows:
which mpirun
~/lib/openmpi_cuda/bin/mpirun"

I built PyTorch following your instructions successfully, but when I use it, an error happens.

error:
pml_ucx.c:226 Error: UCP worker does not support MPI_THREAD_MULTIPLE

test code:

import torch
import torch.distributed as dist

dist.init_process_group(backend='mpi')

t = torch.zeros(5, 5).fill_(dist.get_rank()).cuda()

dist.all_reduce(t)

Dear theevann,
I am trying to implement inter-GPU communication using PyTorch + MPI + GPU. I have run your code; it runs successfully.
In addition, I have modified your code a little so that process 0 runs on GPU 0 and process 1 runs on GPU 1.
However, I find that the modified code cannot run normally because of a "Segmentation Fault". Do you know why?

Following is the modified code.

import os
import socket
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import platform

def run(rank, size):
    if rank == 0:
        tensor = torch.zeros(1).cuda(0)
        # Send the tensor to process 1
        tensor += 1
        dist.send(tensor=tensor, dst=1)
    else:
        tensor = torch.zeros(1).cuda(1)
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])


def init_processes(fn):
    """ Initialize the distributed environment. """
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    size = dist.get_world_size()
    print('I am rank ', rank, ' on ', platform.node())
    fn(rank, size)


if __name__ == "__main__":
    init_processes(run)

Your help will be appreciated. Thank you.

I'm not sure if that is your problem, but using .cuda was failing for me too and I had to use the set_device function.
Did you try it?

Hi theevann, for process 1, manually setting the device works. Thank you.
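For anyone hitting the same thing, a minimal sketch of the fix applied to run() from the code above (assuming the rank number doubles as the GPU index, which is just the convention of that script):

import torch
import torch.distributed as dist

def run(rank, size):
    # Pick this process's GPU as the default device before any .cuda() call
    torch.cuda.set_device(rank)
    tensor = torch.zeros(1).cuda()  # lands on GPU `rank`
    if rank == 0:
        # Send the tensor to process 1
        tensor += 1
        dist.send(tensor=tensor, dst=1)
    else:
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])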

Ciao Evann,

I followed exactly the same steps described here, except with openmpi-3.0.6, and I installed it at /usr/local/. After installation and setting the paths appropriately, when I run 'ompi_info --parsable --all | grep mpi_built_with_cuda_support:value', it still comes out false. I don't know what's going on. Any suggestions? The error is:
[user-HP-ZBook-Create-G7-Notebook-] *** Process received signal ***
[user-HP-ZBook-Create-G7-Notebook-] Signal: Segmentation fault (11)
[user-HP-ZBook-Create-G7-Notebook-] Signal code: Invalid permissions (2)
[user-HP-ZBook-Create-G7-Notebook-] Failing at address: 0x7f890af2fe20
[user-HP-ZBook-Create-G7-Notebook-] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x153c0)[0x7f890ad7d3c0]
[user-HP-ZBook-Create-G7-Notebook-] [ 1] /usr/local/.openmpi/lib/libopen-rte.so.40(orte_errmgr_base_framework+0x0)[0x7f890af2fe20]
[user-HP-ZBook-Create-G7-Notebook-] *** End of error message ***

Hello,

To me there are two possibilities:

  • Either you have multiple OpenMPI installations and the wrong one is showing up
  • Or somehow the compilation didn't pick up the --with-cuda flag
    (Did you install it on a machine that has CUDA installed?)

Uninstall everything and try again?
I am afraid I can't help much more though, sorry…

Thank you very much for all the information on how to get this running, you saved my life.
I would like to add that you need to call torch.cuda.set_device(device) to select a default GPU for each process; otherwise the result of the distributed operations will always end up on cuda:0, which is the default device. Here is a working example of asynchronous send and receive between two processes:

"""
run with:
    mpirun -np 2 python example.py
"""


import os
import torch
import torch.distributed as dist

def run(devices):
    world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    assert len(devices) == world_size
    rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
    device = devices[rank]
    if device != torch.device('cpu'):
        torch.cuda.set_device(device)
    for _ in range(10):
        if rank == 0:
            # Send the tensor to process 1
            tensor0 = torch.rand(1000000, device=device)
            send_req = dist.isend(tensor0, dst=1)
            # Receive the tensor coming from process 1
            tensor1 = torch.zeros(1000000, device=device)
            recv_req = dist.irecv(tensor1, src=1)
        else:
            # Send the tensor to process 0
            tensor1 = torch.rand(1000000, device=device)
            send_req = dist.isend(tensor1, dst=0)
            # Receive the tensor coming from process 0
            tensor0 = torch.zeros(1000000, device=device)
            recv_req = dist.irecv(tensor0, src=0)
        send_req.wait()
        recv_req.wait()
        print(rank, tensor0.mean(), tensor1.mean())


def init_processes(fn, devices):
    """ Initialize the distributed environment. """
    dist.init_process_group('mpi')
    fn(devices)


if __name__ == "__main__":
    devices = [torch.device('cuda:0'), torch.device('cuda:1')]
    init_processes(run, devices)