Unexpected hang when using DistributedDataParallel on two machines

Hi everyone. I am working on training models across multiple machines. Following the instructions in the documentation, I wrote the following code:

  • On machine 1
import torch

torch.distributed.init_process_group(backend='nccl', world_size=2, rank=0, init_method='tcp://172.16.0.246:3456')

net = torch.nn.Linear(256, 128).cuda()
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[0], output_device=0)
  • On machine 2
import torch

torch.distributed.init_process_group(backend='nccl', world_size=2, rank=1, init_method='tcp://172.16.0.246:3456')

net = torch.nn.Linear(256, 128).cuda()
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[0], output_device=0)

Here 172.16.0.246 is the IP of machine 1. However, the code hangs unexpectedly in _distributed_broadcast_coalesced during the initialization of DistributedDataParallel.

Does anyone know what I did wrong?

Hey @IcarusWizard, what error did you see? Does the gloo backend work for you?

Can you run the following command to check if the hostname can resolve to the expected IP on both machines?

getent hosts `hostname`

If the resolved IP is wrong, you can set the NCCL_SOCKET_IFNAME env var to point to the right NIC (e.g., eth0).
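
For example, a minimal sketch: eth0 is just a placeholder here, so replace it with the actual NIC name on each machine. It can equally be exported in the shell, as long as it is set before the first NCCL call.

import os

# 'eth0' is an example name; use the interface that carries the 172.16.x.x traffic on each machine.
# Must be set before the first NCCL call (i.e., before constructing DDP or running any collective).
os.environ["NCCL_SOCKET_IFNAME"] = "eth0"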

Hey Shen!

The strange thing is that there is actually no error. The code gets stuck at _distributed_broadcast_coalesced and cannot be interrupted by Ctrl+C.

I have tried gloo, and it works smoothly, which suggests it is not a firewall issue. I have also set GLOO_SOCKET_IFNAME and NCCL_SOCKET_IFNAME to the correct interface on both machines.

As for the command you suggested, it returns 127.0.1.1 icarus-Polixir on machine 1, and

fe80::b62e:99ff:fe72:d1a1 polixir-G291-Z20-00
fe80::98a3:19ff:fe05:3c61 polixir-G291-Z20-00
fe80::42:adff:fe62:bb24 polixir-G291-Z20-00
fe80::d01c:dff:fe28:8b6f polixir-G291-Z20-00
fe80::1c3d:c8ff:fe62:76cc polixir-G291-Z20-00

on machine 2. I don’t know if it’s related to the issue.

If NCCL_SOCKET_IFNAME points to the correct interface, it should be fine even if the hostname resolves to the wrong address, as the latter is only a fallback for the former.

And since it has already reached the broadcast op in DDP, I would assume the rendezvous in init_process_group was successful. Could you please confirm this by adding the following code right after init_process_group and seeing if it also hangs at this all_reduce?

print("rendezvous done")
tmp = torch.ones(2, 2).cuda()  # the NCCL backend requires CUDA tensors
torch.distributed.all_reduce(tmp)
print(tmp)

Another thing we could try is to set the following env vars and see if any NCCL logs stand out.

export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=ALL

It gets stuck on the all_reduce operation too. I got some new information from the NCCL logs.

On machine 1:

rendezvous done
icarus-Polixir:637574:637574 [0] NCCL INFO Bootstrap : Using [0]wlp0s20f3:172.16.0.246<0>
icarus-Polixir:637574:637574 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so).

icarus-Polixir:637574:637574 [0] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
icarus-Polixir:637574:637574 [0] NCCL INFO NET/Socket : Using [0]wlp0s20f3:172.16.0.246<0>
NCCL version 2.4.8+cuda10.2
icarus-Polixir:637574:637587 [0] NCCL INFO Setting affinity for GPU 0 to ff
icarus-Polixir:637574:637587 [0] NCCL INFO CUDA Dev 0[0], Socket NIC distance :  PHB
icarus-Polixir:637574:637587 [0] NCCL INFO Channel 00 :    0   1
icarus-Polixir:637574:637587 [0] NCCL INFO NET/Socket : GPU Direct RDMA Disabled for GPU 0[0] / HCA 0 (distance 2 >= 2)
icarus-Polixir:637574:637587 [0] NCCL INFO Ring 00 : 1 -> 0 [receive] via NET/Socket/0
icarus-Polixir:637574:637587 [0] NCCL INFO NET/Socket: Using 1 threads and 1 sockets per thread
icarus-Polixir:637574:637587 [0] NCCL INFO Ring 00 : 0 -> 1 [send] via NET/Socket/0

On machine 2:

rendezvous done
polixir-G291-Z20-00:2672:2672 [0] NCCL INFO Bootstrap : Using [0]enp129s0f1:172.16.16.122<0>
polixir-G291-Z20-00:2672:2672 [0] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so).
polixir-G291-Z20-00:2672:2672 [0] NCCL INFO NET/IB : Using [0]mlx5_1:1/RoCE ; OOB enp129s0f1:172.16.16.122<0>
polixir-G291-Z20-00:2672:2763 [0] NCCL INFO Setting affinity for GPU 0 to ffff0000,00000000,ffff0000,00000000
polixir-G291-Z20-00:2672:2763 [0] NCCL INFO CUDA Dev 0[4], IB NIC distance :  SYS
polixir-G291-Z20-00:2672:2763 [0] NCCL INFO NET/IB : GPU Direct RDMA Disabled for GPU 0[4] / HCA 0 (distance 4 >= 2)
polixir-G291-Z20-00:2672:2763 [0] NCCL INFO Ring 00 : 0 -> 1 [receive] via NET/IB/0
polixir-G291-Z20-00:2672:2763 [0] NCCL INFO Ring 00 : 1 -> 0 [send] via NET/IB/0
polixir-G291-Z20-00:2672:2763 [0] NCCL INFO NET/IB: Dev 0 Port 1 qpn 2358 mtu 3 GID 0 (80FE/A1D172FEFF992EB6)

The first thing I noticed is that it is trying to find libnccl-net.so. However, I cannot find this file in the official release of NCCL. I have tried running both scripts on machine 1 alone, and that works just fine, so I don't think this is the source of the problem.

Any other thoughts?

Did you install NCCL yourself, or are you using the one bundled with PyTorch?

Not sure if I interpreted the logs correctly, but it looks like machine 1 is trying to use TCP sockets while machine 2 is trying to use IB? What if you set NCCL_IB_DISABLE on both machines? https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-ib-disable
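
For example, something along these lines before anything touches NCCL (or the equivalent export in the shell); the variable is documented at the link above:

import os

# Disable the InfiniBand/RoCE transport so NCCL falls back to plain TCP sockets.
# Needs to be set on both machines before the first NCCL call.
os.environ["NCCL_IB_DISABLE"] = "1"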


IT WORKS after disabling IB. It seems IB is a hardware-related feature, and the network interface on machine 1 simply doesn't support it.
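
For reference, here is roughly what the working setup on machine 1 looks like now; machine 2 is the same except rank=1 and its own interface name. The env vars could equally be exported in the shell, as long as they are set before the first NCCL call.

import os

# Force NCCL to use plain TCP sockets on both machines.
os.environ["NCCL_IB_DISABLE"] = "1"
# wlp0s20f3 is the wireless NIC on machine 1 (enp129s0f1 on machine 2).
os.environ["NCCL_SOCKET_IFNAME"] = "wlp0s20f3"

import torch

torch.distributed.init_process_group(backend='nccl', world_size=2, rank=0, init_method='tcp://172.16.0.246:3456')

net = torch.nn.Linear(256, 128).cuda()
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[0], output_device=0)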

I am using the NCCL bundled with PyTorch. I have also tried installing NCCL myself, but it makes no difference.

Thanks a lot for your help!
