How to properly use distributed PyTorch with InfiniBand support

I’m using PyTorch on a cluster connected by InfiniBand (56Gb FDR).
I want to run distributed training where each process controls one GPU and the gradients are averaged across processes by ‘allreduce’ (I’m using the MPI backend). I expect this to scale well, just like MPI-based Caffe with InfiniBand support.
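
The per-iteration pattern I have in mind is roughly the following (just a sketch: model, criterion, optimizer and loader stand for the usual ImageNet-example objects, and average_gradients is the helper shown further down in this thread):

# One process per GPU: every process runs the same loop and all-reduces the
# gradients after the backward pass, so all replicas take the same step.
for input, target in loader:
    output = model(input.cuda())
    loss = criterion(output, target.cuda())
    optimizer.zero_grad()
    loss.backward()
    average_gradients(model)   # all_reduce the gradients and divide by world size
    optimizer.step()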

So I built PyTorch from source with WITH_DISTRIBUTED=1, and I’m sure the MPI libraries are built with InfiniBand support (they work well with MPI-based Caffe). I expected this to run faster on 8 GPUs than ‘DataParallel’, since it bypasses the GIL, but the performance was actually worse.

After some profiling I found that the bottleneck is in ‘allreduce’, which should be faster over InfiniBand. To make sure the communication is actually going through IB, I tested the point-to-point bandwidth using dist.send/recv. It’s about 3.7GB/s, which is odd: Ethernet generally cannot reach that number, yet it is only about half of InfiniBand’s theoretical bandwidth (56Gb FDR is roughly 6.8GB/s of data rate). I also tested the bandwidth with the OSU benchmark, which shows 11GB/s intra-node and 6GB/s inter-node.

Here are my numbers:
On Titan Xp, ResNet-50 with batch size 32 per GPU runs at 0.15s per iteration on 1 GPU and 0.3s per iteration on 8 GPUs when using ‘DataParallel’.
When using 8 processes and averaging gradients with ‘allreduce’ after loss.backward, it runs at 0.45s per iteration, which is even slower than ‘DataParallel’. About 0.3s of that is spent in allreduce.

It seems that MPI is not working properly with PyTorch under my settings. Am I missing something? I have searched and experimented for 2 days but still cannot get it to work; any help would be really appreciated!

PS: I also tried the ‘gloo’ backend with the InfiniBand patch mentioned here; it runs, but performs even worse (similar to the ‘tcp’ backend).


We would recommend using the Gloo backend. It currently has some known issues with IB, which we are actively working on.


The MPI backend isn’t that well tested, and depending on which MPI version you are using, MPI’s thread support also has known performance issues and is not well tested either. PyTorch’s MPI backend currently uses MPI threads. That’s why we recommend the Gloo backend.

NCCL2 is another backend we are actively working on that will support IB.
https://github.com/pytorch/pytorch/pull/3435 is the PR, and you can check the status there. It will hopefully be available soon.

One of the problems is that the MPI backend doesn’t really support CUDA. Apparently we’re missing some checks for this, but it should be easy to add CUDA support. Thanks for the feedback.

Can you please provide some more details, like the script you are using for training (it doesn’t have to be the exact same script, just something we can use to benchmark) and the scripts you used for benchmarking bandwidth?

One of the reasons you may see it running slower is this piece. You can see that MPI doesn’t allow us to run all-reduce in place, so we have to do an extra memcpy that lowers the bandwidth. Additionally, the memcpy is incorrect if the tensor is a CUDA tensor.
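
To make the cost concrete, the difference between the two call patterns looks roughly like this (a Python sketch of the idea only, not the actual THD C++ code; `collective` stands in for the underlying MPI all-reduce):

def allreduce_with_staging(tensor, collective):
    # Out-of-place pattern: MPI gets separate send/recv buffers, so we pay
    # for a staging buffer plus two extra copies on every call.
    staging = tensor.clone()     # extra copy #1: fill the send buffer
    collective(staging)          # the reduction runs on the staging buffer
    tensor.copy_(staging)        # extra copy #2: back into the user's tensor

def allreduce_in_place(tensor, collective):
    # In-place pattern: the reduction writes directly into the user's tensor,
    # with no staging buffer and no extra copies.
    collective(tensor)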

Thanks for your reply!

I’m currently testing with CPU tensors only. It seems that the data transfer between processes dominates.

I use the following script to test bandwidth (I use Slurm to run it):

from __future__ import division
from __future__ import print_function
import os
import torch
import torch.distributed as dist
import argparse
import numpy as np
import time

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('--dist-backend', default='gloo', type=str,
                    help='distributed backend')

def run_allreduce(rank, size):
    # Time n=10 all-reduces of a float32 tensor with `size` elements.
    data = torch.from_numpy(np.ones(size,dtype=np.float32))
    t0 = time.time()
    n = 10
    for i in range(n):
        dist.all_reduce(data, op=dist.reduce_op.SUM)
    t1 = time.time()
    print('average time:', (t1-t0)/n)

def run_sendrecv(rank, size):
    # Point-to-point bandwidth test between exactly two ranks (rank 0 sends, rank 1 receives).
    assert(dist.get_world_size() == 2)
    data = torch.from_numpy(np.zeros(size,dtype=np.float32))
    t0 = time.time()
    n = 10
    for i in range(n):
        if rank == 0:
            dist.send(data, 1)
        else:
            dist.recv(data, 0)
    t1 = time.time()
    print('average time:', (t1-t0)/n, 'BW:', size*4/(t1-t0)*n/1024/1024/1024,'GB/s')

args = parser.parse_args()

# Rank, world size and node list come from the Slurm environment (one task per process).
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']

# Cluster-specific: take the first node in the list and turn its hostname into
# an IP address (the node names here encode the IP after an 8-character prefix).
if '[' in node_list:
    node_list = node_list.split(',')[0].replace('[', '')
addr = node_list[8:].replace('-', '.')
addr = 'tcp://'+addr+':23456'
print(addr)

dist.init_process_group(backend=args.dist_backend, init_method=addr, world_size=ntasks, rank=proc_id)

run_allreduce(proc_id, 1024*1024*25)
#run_sendrecv(proc_id, 1024*1024*25)

The training script is modified from the official ImageNet example; since it’s a bit longer, I’ll just show the concept. The DataParallel/DistributedDataParallel code is removed; instead I use the following function after loss.backward to average the gradients:

def average_gradients(model):
    """ Gradient averaging. """
    size = float(dist.get_world_size())
    for param in model.parameters():
        new_grad = param.grad.data.cpu()                  # copy to CPU (the MPI backend can't take CUDA tensors)
        dist.all_reduce(new_grad, op=dist.reduce_op.SUM)
        param.grad.data = new_grad.cuda()                 # copy the reduced gradient back to the GPU
        param.grad.data /= size

As you pointed out, there are many memory copies here.
It also seems that DistributedDataParallel does much more work; I will take a closer look.
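
For reference, the version I’d like to end up with is an in-place, GPU-resident one along these lines (a sketch; it needs a backend that can all-reduce CUDA tensors in place, which is exactly what is being discussed here):

def average_gradients_inplace(model):
    """All-reduce gradients directly on the GPU, with no staging copies."""
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)  # in place, on the CUDA tensor
        param.grad.data /= world_size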

Thanks for the help!

Is there any experimental code for Gloo with IB support that I can try?

The new NCCL2 backend looks quite ready; I will try it and report back soon.

Thanks again for your effort!

Hi teng-li,
I have tried your new NCCL2 backend, but I’m getting the following error (I tried NCCL 2.0.5 and NCCL 2.1; both raise the same error):

NCCL error in: pytorch-nccl2/torch/lib/THD/base/data_channels/DataChannelNccl.cpp:330, unhandled cuda error

By setting NCCL_DEBUG=INFO, I get the following output:

INFO NET : Using interface ib1:10.10.120.68<0>
INFO NET/IB : Using interface ib1 for sideband communication
INFO NET/IB: [0] mlx4_0:2/IB
INFO Using internal Network IB
INFO NET : Using interface ib1:10.10.120.68<0>
INFO NET/IB : Using interface ib1 for sideband communication
INFO NET/IB: [0] mlx4_0:2/IB
INFO Using internal Network IB
NCCL version 2.0.5 compiled with CUDA 8.0
INFO CUDA Dev 0, IB Ports : mlx4_0/2(SOC)
INFO CUDA Dev 0, IB Ports : mlx4_0/2(SOC)
INFO Using 256 threads
INFO [0] Ring 0 :    0   1
INFO 1 -> 0 via P2P/IPC
INFO 1 -> 0 via P2P/IPC
INFO 0 -> 1 via P2P/IPC
INFO 0 -> 1 via P2P/IPC

transport/p2p.cu:429 WARN failed to open CUDA IPC handle : invalid argument
INFO transport/p2p.cu:439 -> 1
INFO init.cu:373 -> 1
INFO init.cu:432 -> 1
INFO
transport/p2p.cu:429 WARN misc/group.cu:87 -> 1 [Async thread]
failed to open CUDA IPC handle : invalid argument
INFO transport/p2p.cu:439 -> 1
INFO init.cu:373 -> 1
INFO init.cu:432 -> 1
INFO misc/group.cu:87 -> 1 [Async thread]

Any ideas?

Oh I get it!

It is caused by the following line:

os.environ['CUDA_VISIBLE_DEVICES'] = str(proc_id%8)

I set this in each process to avoid allocating extra GPU memory on device 0 (without this environment variable set, there are 8 processes on GPU 0, and 7 of them each occupy around 300MB of extra memory).

I removed the line above and added the following line instead:

torch.cuda.set_device(proc_id%8)

then everything runs smoothly!!!
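
For anyone hitting the same problem, my per-process setup now looks roughly like this (a sketch; proc_id, ntasks and addr are derived from the Slurm environment exactly as in the benchmark script earlier, and ‘nccl’ is the backend name used by the PR):

torch.cuda.set_device(proc_id % 8)          # pin this process to its own GPU before any CUDA work
dist.init_process_group(backend='nccl', init_method=addr,
                        world_size=ntasks, rank=proc_id)
model = model.cuda()                        # the model now lives on the pinned device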

On 8 Titan Xps, ResNet-50 with batch size 32 runs at 0.2s per iteration, far better than ‘DataParallel’ at 0.3s per iteration, and the improvement is even more pronounced for larger models. The cool thing is that NCCL2 automatically detects InfiniBand and extends directly to multi-node! With 16 GPUs it is 0.23s per iteration; considering that I’m using raw allreduce-based synchronous SGD without any other optimizations, this result is really good!

Thank you teng-li! Your NCCL2 backend works really well! I hope it will be merged into master soon.


@zjoe I just pushed a patch that improves DataChannelMPI in a way that should make it CUDA compatible and improve performance. Can you let me know what the numbers are if you build from this branch? https://github.com/pytorch/pytorch/pull/3817

Thanks!

Gloo with GPU Direct IB support still has some issues that we are currently working on. But you can always use IP over IB with Gloo by simply binding the --dist-url to the address of your IB card. IP over IB will work fine with Gloo as well.
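
For example, with a script like the benchmark above, IP over IB just means pointing init_method (which is what --dist-url sets in the ImageNet example) at the IP address assigned to your IB interface; the address below is the ib1 address from your NCCL log, so replace it with your own:

# Gloo over IP-over-IB: use the IB interface's IP for the rendezvous address
# and the traffic goes over the IPoIB interface instead of Ethernet.
dist.init_process_group(backend='gloo',
                        init_method='tcp://10.10.120.68:23456',
                        world_size=ntasks, rank=proc_id)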


Thanks for trying it out and giving us the feedback on the NCCL2 backend. I am glad it works well for you. Feel free to reach out to us if you are seeing any further issues. Happy to help.

This error is typically caused by a library version mismatch; I would check for that first when it happens.

Sorry for the late response, I’ve been a little busy. I actually searched for the in-place allreduce after you pointed it out and tried it with CPU tensors. That one-line change does bring some improvement, but I hadn’t made it CUDA compatible. Thanks for your patch, I will try it and report back as soon as possible!

@zjoe I accidentally committed a few bugs in that PR. They’re already fixed in master, so please use master when building!

@apaszke I tried your patch with mvapich-2.2.2-cuda8.0; the time per iteration for ResNet-50 with batch size 32 is reduced from 0.45s to 0.4s (on 8 GPUs). Unfortunately the processes get stuck when I try to use more than one node, which needs further debugging. Anyway, it does bring some improvement!

That’s weird, did it work for you before? Are you just using DistributedDataParallel or is it your custom code (and if so, what collectives/distributed functions are you using)?

I remember that it didn’t work before your patch (when running on multiple nodes). I use this custom code:

for param in model.parameters():
    dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)   # in place, directly on the CUDA gradient

Hi zjoe,
I noticed that you share the same interests as me. Recently I’ve been struggling to find ibverbs support for distributed PyTorch. I use Gloo as the backend and ran a simple distributed MNIST test, modified from here. However, I find that all the packets are transmitted over TCP. Do you know how to use ibverbs rather than TCP?
Any help would be appreciated!

I recommend you try this PR, which has already been merged into master.
NCCL2 will automatically find and use ibverbs.
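
A minimal sketch of what that looks like in practice (adapt the rank/world-size handling to your launcher, and fill in a rendezvous address reachable from every node):

import os
import torch
import torch.distributed as dist

# Turn on NCCL's logging so you can confirm it picked the IB transport:
# look for "NET/IB" lines like the ones earlier in this thread.
os.environ['NCCL_DEBUG'] = 'INFO'

rank = int(os.environ['SLURM_PROCID'])       # adapt to your launcher
world_size = int(os.environ['SLURM_NTASKS'])

torch.cuda.set_device(rank % torch.cuda.device_count())   # pin this process to one GPU
dist.init_process_group(backend='nccl',
                        init_method='tcp://<master-ip>:23456',   # placeholder: your master node's address
                        world_size=world_size, rank=rank)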

Hi, can you help me with the question [Build from source with MVAPICH]? I also want to build PyTorch with MVAPICH, but I’ve run into some problems. Thank you!