PyTorch NCCL DDP freezes but Gloo works

I am trying to get the NCCL backend working on my Ubuntu 20.04 system, which has two NVIDIA 2070S GPUs and runs PyTorch 1.8.

My test script is based on the PyTorch docs, but with the backend changed from "gloo" to "nccl".

When the backend is "gloo", the script finishes in a few seconds:

$ time python test_ddp.py 
Running basic DDP example on rank 0.
Running basic DDP example on rank 1.

real    0m4.839s
user    0m4.980s
sys     0m1.942s

However, when the backend is set to "nccl", the script hangs with the output below and never returns to the bash prompt:

$ python test_ddp.py 
Running basic DDP example on rank 1.
Running basic DDP example on rank 0.

The same problem occurs when InfiniBand is disabled:

$ NCCL_IB_DISABLE=1 python test_ddp.py
Running basic DDP example on rank 1.
Running basic DDP example on rank 0.

I’m using the following packages:

  • pytorch 1.8.1
  • cudatoolkit 11.1.1
  • python 3.8.8

How can I fix this problem when using NCCL? Thank you!

Python code used for testing NCCL:

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp

from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # gloo: works
    # dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # nccl: hangs forever
    dist.init_process_group(
        "nccl", init_method="tcp://10.1.1.20:23456", rank=rank, world_size=world_size
    )


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def demo_basic(rank, world_size):
    print(f"Running basic DDP example on rank {rank}.")
    setup(rank, world_size)

    # create model and move it to GPU with id rank
    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    cleanup()


def run_demo(demo_fn, world_size):
    mp.spawn(demo_fn, args=(world_size,), nprocs=world_size, join=True)


if __name__ == "__main__":
    run_demo(demo_basic, 2)

Thanks for posting your code and details!

I see that when you initialize the process group with NCCL, you specify init_method as "tcp://" and provide a hard-coded IP and port. Can you ensure that address and port are reachable? Alternatively, since you are already setting the address and port as environment variables, and that works for Gloo, you can remove the init_method argument from init_process_group; it will then default to "env://", which should work as well:

dist.init_process_group("nccl", rank=rank, world_size=world_size)

Here is the documentation for init_process_group: Distributed communication package - torch.distributed — PyTorch 1.8.1 documentation

Please let me know if this works.

Yes, you are right! I’ve got it running with NCCL by changing the setup function as suggested:

def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    dist.init_process_group("nccl", rank=rank, world_size=world_size)

However, is this now using TCP initialization?

Yes, the underlying communication uses TCP. During initialization, the default "env://" method reads the MASTER_ADDR, MASTER_PORT, rank, and world_size values you defined and uses them to create TCPStore instances on all workers (a server on rank 0 and clients on the other ranks).
https://pytorch.org/docs/master/distributed.html#torch.distributed.Store
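
If you ever want to manage the store yourself, you can also construct a TCPStore explicitly and pass it to init_process_group via the store argument instead of an init_method. A minimal sketch, assuming the same localhost/12355 address you already use (setup_with_store is just an illustrative name):

from datetime import timedelta

import torch.distributed as dist


def setup_with_store(rank, world_size):
    # rank 0 hosts the TCPStore server; the other ranks connect to it as clients
    store = dist.TCPStore(
        "localhost", 12355, world_size, rank == 0, timedelta(seconds=30)
    )
    dist.init_process_group("nccl", store=store, rank=rank, world_size=world_size)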