Torch distributed not working on two machines [nccl backend]

Hi,

I am running a simple application on two machines with 2 GPUs each, and it is throwing an error. The application works fine on a single machine with 2 GPUs.

The NCCL debug and error output is below:

dml4:26072:26072 [1] NCCL INFO Bootstrap : Using [0]XXXXXX<0> [1]enp0s20f0u1u6:169.254.95.120<0> [2]virbr0:192.168.122.1<0>
dml4:26071:26071 [0] NCCL INFO Bootstrap : Using [0]XXXXX<0> [1]enp0s20f0u1u6:169.254.95.120<0> [2]virbr0:XXXXX<0>
dml4:26072:26072 [1] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so).
dml4:26071:26071 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so).
dml4:26072:26072 [1] NCCL INFO NET/IB : Using [0]mlx5_0:1/RoCE ; OOB enp88s0:9.1.44.100<0>
dml4:26071:26071 [0] NCCL INFO NET/IB : Using [0]mlx5_0:1/RoCE ; OOB enp88s0:9.1.44.100<0>
dml4:26072:26240 [1] NCCL INFO Setting affinity for GPU 1 to ffff,f00000ff,fff00000
dml4:26071:26242 [0] NCCL INFO Setting affinity for GPU 0 to 0fffff00,000fffff
dml4:26072:26240 [1] NCCL INFO CUDA Dev 1[1], IB NIC distance : SYS
dml4:26071:26242 [0] NCCL INFO CUDA Dev 0[0], IB NIC distance : NODE
dml4:26071:26242 [0] NCCL INFO Ring 00 : 1 -> 2 [receive] via NET/IB/0
dml4:26071:26242 [0] NCCL INFO Ring 00 : 2[0] -> 3[1] via direct shared memory
dml4:26072:26240 [1] NCCL INFO Ring 00 : 3 -> 0 [send] via NET/IB/0

dml4:26072:26240 [1] misc/ibvwrap.cc:252 NCCL WARN Call to ibv_reg_mr failed
dml4:26072:26240 [1] NCCL INFO transport/net_ib.cc:601 -> 2
dml4:26072:26240 [1] NCCL INFO include/net.h:24 -> 2
dml4:26072:26240 [1] NCCL INFO transport/net.cc:360 -> 2
dml4:26072:26240 [1] NCCL INFO init.cc:669 -> 2
dml4:26072:26240 [1] NCCL INFO init.cc:815 -> 2
dml4:26072:26240 [1] NCCL INFO init.cc:951 -> 2
dml4:26072:26240 [1] NCCL INFO misc/group.cc:69 -> 2 [Async thread]

dml4:26071:26242 [0] misc/ibvwrap.cc:252 NCCL WARN Call to ibv_reg_mr failed
dml4:26071:26242 [0] NCCL INFO transport/net_ib.cc:601 -> 2
dml4:26071:26242 [0] NCCL INFO include/net.h:24 -> 2
dml4:26071:26242 [0] NCCL INFO transport/net.cc:388 -> 2
dml4:26071:26242 [0] NCCL INFO init.cc:679 -> 2
dml4:26071:26242 [0] NCCL INFO init.cc:815 -> 2
dml4:26071:26242 [0] NCCL INFO init.cc:951 -> 2
dml4:26071:26242 [0] NCCL INFO misc/group.cc:69 -> 2 [Async thread]
Traceback (most recent call last):
  File "conv_dist.py", line 118, in <module>
    main()
  File "conv_dist.py", line 51, in main
    mp.spawn(train, nprocs=args.gpus, args=(args,), join=True)
  File "/work/tools/envs/dine2/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/work/tools/envs/dine2/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/work/tools/envs/dine2/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 119, in join
    raise Exception(msg)
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/work/tools/envs/dine2/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/us4j4248/pt_dist/conv_dist.py", line 75, in train
    model = DDP(model, device_ids=[gpu])
  File "/work/tools/envs/dine2/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 285, in __init__
    self.broadcast_bucket_size)
  File "/work/tools/envs/dine2/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 496, in _distributed_broadcast_coalesced
    dist._broadcast_coalesced(self.process_group, tensors, buffer_size)
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1591914838379/work/torch/lib/c10d/ProcessGroupNCCL.cpp:514, unhandled system error, NCCL version 2.4.8

PS: I have removed the IP addresses above.

Thanks

Hey @nash, could you please share a minimum repro of this error?

Hey Shen,
I am running a simple application using torch.distributed.launch on two machines, each with 2 GPUs. It throws the above error.
CUDA 10.2
PyTorch 1.5.1
NCCL backend, with these NCCL packages installed on the system:
libnccl-devel-2.5.6-1+cuda10.2.x86_64
libnccl-2.5.6-1+cuda10.2.x86_64
libnccl-static-2.5.6-1+cuda10.2.x86_64

But torch.cuda.nccl.version() shows me 2.4.
I ran the NCCL tests and they work fine.

# Imports needed to run this snippet (added for completeness).
# ConvNet is a small CNN defined elsewhere in conv_dist.py, and args.local_rank
# is the --local_rank argument passed in by torch.distributed.launch.
import os

import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision
import torchvision.transforms as transforms
from torch.nn.parallel import DistributedDataParallel as DDP


def train(args):
    current_env = os.environ.copy()
    dist.init_process_group(backend='nccl', init_method='env://')
    model = ConvNet()
    torch.cuda.set_device(args.local_rank)
    model.cuda(args.local_rank)
    batch_size = 256
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.local_rank)
    optimizer = torch.optim.SGD(model.parameters(), 1e-4)

    model = DDP(model, device_ids=[args.local_rank])
    # Data loading code
    train_dataset = torchvision.datasets.MNIST(root='./data',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=int(current_env["WORLD_SIZE"]), rank=args.local_rank)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        sampler=train_sampler)

Please let me know if you need any more info.
Thanks

Hey @nash, thanks for sharing the code. The code looks correct to me, except that the rank argument in DistributedSampler might need to be the global rank (i.e., current_env["RANK"]) instead of the local rank. But this is not the cause of the error, as the error was thrown in the DDP constructor when it tries to run a broadcast.
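For concreteness, here is a minimal sketch of that change, assuming the RANK and WORLD_SIZE env vars set by torch.distributed.launch are visible in each worker process:

import os

import torch
import torchvision
import torchvision.transforms as transforms

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, transform=transforms.ToTensor(), download=True)

# Use the global rank from the environment (0..3 across both nodes), not
# args.local_rank (0..1 per node), so every process gets a distinct shard.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=int(os.environ["WORLD_SIZE"]),
    rank=int(os.environ["RANK"]))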

Regarding the error:

  1. Curious, does the gloo backend work?
  2. Could you please print out the env vars set by the launching script? Something like what this example does with env_dict; see the sketch after this list.
  3. Could you try running the following command on both nodes? If the returned IP is not what you intended, you might need to set either GLOO_SOCKET_IFNAME or NCCL_SOCKET_IFNAME as mentioned here, depending on which backend you are using.
    getent hosts `hostname`
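For (2), something along these lines at the top of train() would do it (a sketch based on that example; which keys are present depends on how you launch):

import os

def print_launch_env():
    # Print the rendezvous-related env vars that the launching script sets,
    # so we can compare MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE on every process.
    env_dict = {key: os.environ.get(key)
                for key in ("MASTER_ADDR", "MASTER_PORT",
                            "RANK", "LOCAL_RANK", "WORLD_SIZE")}
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")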
    

Hey @mrshenli, thanks for your response.

  1. The Gloo backend works perfectly fine with this code; the NCCL backend throws the error.

  2. Here are the env var outputs with the Gloo backend.

Node 0 output:
Initializing process group with: {'MASTER_ADDR': 'a.b.c.d', 'MASTER_PORT': '12354', 'RANK': '1', 'WORLD_SIZE': '4'}
[3196] Initializing process group with: {'MASTER_ADDR': 'a.b.c.d', 'MASTER_PORT': '12354', 'RANK': '0', 'WORLD_SIZE': '4'}
[3196] world_size = 4, rank = 0, backend=gloo
[3197] world_size = 4, rank = 1, backend=gloo

Node 1 output:
Initializing process group with: {'MASTER_ADDR': 'a.b.c.d', 'MASTER_PORT': '12354', 'RANK': '2', 'WORLD_SIZE': '4'}
[89966] Initializing process group with: {'MASTER_ADDR': 'a.b.c.d', 'MASTER_PORT': '12354', 'RANK': '3', 'WORLD_SIZE': '4'}
[89966] world_size = 4, rank = 3, backend=gloo
[89965] world_size = 4, rank = 2, backend=gloo

  3. The command returns the domain name/address of each host (a.b.c.d, etc.). I am not setting GLOO_SOCKET_IFNAME or NCCL_SOCKET_IFNAME anywhere.

My system has NCCL 2.5, while torch (torch.cuda.nccl.version()) shows 2.4.8.
Could this be the problem?
How can I upgrade the NCCL version in torch?

Thanks.

If Gloo works fine then it means all the env vars and configs should be correct.

How can I upgrade the NCCL version in torch?

That will require modifying the PyTorch NCCL submodule and recompiling, as in "updating nccl to 2.7.3 by agolynski · Pull Request #40622 · pytorch/pytorch · GitHub". You can pull this PR and compile from it, which should then use NCCL 2.7.3.

Another option is to set export USE_SYSTEM_NCCL=1 and then compile from source; it should then use the 2.5 that you installed on the machine.
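After rebuilding either way, you can double-check which NCCL version the torch build actually reports with a quick sketch like this (the return format of torch.cuda.nccl.version() differs across PyTorch releases, so both cases are handled):

import torch

print("torch:", torch.__version__)

# Older releases return an int such as 2408 (i.e., 2.4.8); newer ones return a tuple.
nccl = torch.cuda.nccl.version()
if isinstance(nccl, int):
    major, rest = divmod(nccl, 1000)
    minor, patch = divmod(rest, 100)
    print("nccl:", "{}.{}.{}".format(major, minor, patch))
else:
    print("nccl:", ".".join(str(v) for v in nccl))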

Thanks @mrshenli.
As you mentioned, PyTorch ships with a precompiled NCCL, and both nodes use the same version of NCCL.
Does that mean the NCCL version is not the problem?

Did you notice this "misc/ibvwrap.cc:252 NCCL WARN Call to ibv_reg_mr failed" in the logs?

I tried to build torch from source, but I hit another roadblock there as well:
"Performing Test SUPPORT_GLIBCXX_USE_C99 - Failed"

Thanks.

An error in ibv (i.e., InfiniBand verbs) indicates problems with GPU Direct, which NCCL tries to use for RDMA but which Gloo doesn’t. You can try to confirm that this is indeed the issue by running with the NCCL_IB_DISABLE=1 env var. That may work but would probably end up being slower. In that case you might want to follow the instructions here to troubleshoot InfiniBand issues: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#gpu-direct
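If you spawn the workers from Python, one way to run that test is to set the variable before the process group is created (a sketch; exporting NCCL_IB_DISABLE=1 in the shell before launching is equivalent):

import os

import torch.distributed as dist

# Disable NCCL's InfiniBand/RoCE (ibv) transport so it falls back to TCP sockets.
# This must be set before the first NCCL communicator is created.
os.environ["NCCL_IB_DISABLE"] = "1"
# If the fallback picks the wrong interface, NCCL_SOCKET_IFNAME can pin it to
# the one carrying MASTER_ADDR (the interface name below is just an example):
# os.environ["NCCL_SOCKET_IFNAME"] = "enp88s0"

dist.init_process_group(backend="nccl", init_method="env://")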


Hi @lcw, thanks. Indeed it is an NCCL/InfiniBand issue, and setting NCCL_IB_DISABLE=1 works fine.