Send/Recv is slower in NCCL than in Gloo

Hi!
I tried to replace Gloo with NCCL in my code because PyTorch 1.11 supports send/recv with NCCL.

Because the official documentation recommends using NCCL for GPU training, I expected NCCL to be faster than Gloo. However, I found that send/recv is slower with NCCL than with Gloo.

Is this result expected?
Also, is there any way to make send/recv with NCCL faster?

Code:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time

def run(rank, size, backend):
    total_time = 0.
    
    for _ in range(10):
        if rank == 0:
            tensor = torch.randn((100,100), device="cuda:1")
            time.sleep(5)  # wait until process 1 initializes the variable 'tensor'.
            
            t1_time = time.time()
            if backend == "nccl":
                dist.send(tensor=tensor, dst=1)
            else:
                dist.send(tensor=tensor.to("cpu"), dst=1)
                
            total_time += time.time() - t1_time
        else:
            if backend == "nccl":
                tensor = torch.randn((100,100), device="cuda:2")
                dist.recv(tensor=tensor, src=0)
            else:
                tensor = torch.randn((100,100), device="cpu")
                dist.recv(tensor=tensor, src=0)
                
    if rank == 0:
        print(f"{backend}: {total_time} sec")
        
def init_process(rank, size, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, backend)


if __name__ == "__main__":
    mp.set_start_method("spawn")
    
    for backend in ["nccl", "gloo"]:
        size = 2
        processes = []
        for rank in range(size):
            p = mp.Process(target=init_process, args=(rank, size, run, backend))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

Results:

nccl: 3.5865402221679688 sec
gloo: 0.00415349006652832 sec

Environment:

  • Python 3.10.4
  • CUDA 11.4
  • PyTorch 1.11.0
  • GPUs: NVIDIA Quadro A6000 * 2

Best regards!


NCCL has a much slower warmup than Gloo. PyTorch sets up the NCCL communicators lazily, so the setup cost shows up on the first call.

I can reproduce your results locally: the vast majority of the time is spent in the first call to send. After that, the performance of the two backends is what you'd expect.
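The effect described above, where a one-time setup cost lands on the first call, can be reproduced without GPUs. The sketch below is a hypothetical stand-in, not PyTorch code: `LazyChannel` simulates a backend that builds its connection on first use (the `setup_cost` and `per_call_cost` sleeps are made-up numbers), and `benchmark` times warmup iterations separately from the measured ones so the setup cost does not pollute the result.

```python
import time


class LazyChannel:
    """Hypothetical stand-in for a backend that creates its communicator lazily."""

    def __init__(self, setup_cost=0.2, per_call_cost=0.001):
        self.setup_cost = setup_cost        # simulated one-time initialization (seconds)
        self.per_call_cost = per_call_cost  # simulated cost of each send (seconds)
        self.initialized = False

    def send(self):
        if not self.initialized:
            # Only the first call pays for setting up the connection.
            time.sleep(self.setup_cost)
            self.initialized = True
        time.sleep(self.per_call_cost)


def benchmark(fn, warmup=2, iters=10, sync=lambda: None):
    """Time `warmup` calls and `iters` calls of `fn` separately.

    `sync` is called before each clock stop; for asynchronous GPU work it
    would be something like torch.cuda.synchronize.
    """
    start = time.perf_counter()
    for _ in range(warmup):
        fn()
    sync()
    warmup_time = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    sync()
    bench_time = time.perf_counter() - start
    return warmup_time, bench_time


if __name__ == "__main__":
    channel = LazyChannel()
    warmup_time, bench_time = benchmark(channel.send)
    print(f"warmup: {warmup_time:.3f} sec, benchmark: {bench_time:.3f} sec")
```

Wrapping a real `dist.send(...)` call in `fn` and passing `sync=torch.cuda.synchronize` would give the same warmup/measure split for the NCCL case; treat this as a sketch of the methodology rather than a drop-in benchmark.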


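Also, the timing code does not call torch.cuda.synchronize(). With NCCL, dist.send and dist.recv on CUDA tensors run asynchronously on a CUDA stream, so the host has to wait for them to finish before stopping the timer.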
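To illustrate why the missing synchronization matters: when the host call only enqueues the work, a timer stopped right after the call measures mostly launch overhead. The stdlib analogy below uses a thread pool as a stand-in for the GPU stream; `transfer` and its 0.1 s duration are hypothetical. In real CUDA code, the `future.result()` step plays the role of `torch.cuda.synchronize()`.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def transfer():
    """Hypothetical stand-in for a transfer that runs asynchronously."""
    time.sleep(0.1)
    return "done"


def time_launch_only(executor):
    # Stops the clock as soon as the work is queued (like timing dist.send without syncing).
    start = time.perf_counter()
    future = executor.submit(transfer)
    elapsed = time.perf_counter() - start
    future.result()  # drain the work so it does not overlap the next measurement
    return elapsed


def time_until_complete(executor):
    # Waits for the work to finish before stopping the clock (like torch.cuda.synchronize()).
    start = time.perf_counter()
    future = executor.submit(transfer)
    future.result()
    return time.perf_counter() - start


if __name__ == "__main__":
    with ThreadPoolExecutor(max_workers=1) as executor:
        print(f"launch only:    {time_launch_only(executor):.4f} sec")
        print(f"until complete: {time_until_complete(executor):.4f} sec")
```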


Here’s an updated example that runs a few warmup iterations and incorporates Maxwell’s point about waiting for CUDA ops.

Note that while waiting for CUDA ops makes sense when benchmarking, it’s usually not necessary otherwise.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time

def run(rank, size, backend):

    def run_once():
        if rank == 0:
            tensor = torch.randn((100,100), device="cuda:0")
            
            t1_time = time.time()
            if backend == "nccl":
                dist.send(tensor=tensor, dst=1)
                # Wait for all data to be sent
                torch.cuda.synchronize()
            else:
                dist.send(tensor=tensor.to("cpu"), dst=1)
                
            return time.time() - t1_time
        else:
            if backend == "nccl":
                tensor = torch.randn((100,100), device="cuda:1")
                dist.recv(tensor=tensor, src=0)
            else:
                tensor = torch.randn((100,100), device="cpu")
                dist.recv(tensor=tensor, src=0)
            return 0

    # execute a few rounds of warmup
    warmup_time = 0.
    for _ in range(2):
        warmup_time += run_once()
    # measure runtime
    benchmark_time = 0.
    for _ in range(10):
        benchmark_time += run_once()
 
    if rank == 0:
        print(f"{backend}: warmup: {warmup_time} sec benchmark time: {benchmark_time} sec")
        
def init_process(rank, size, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, backend)


if __name__ == "__main__":
    mp.set_start_method("spawn")
    
    for backend in ["nccl", "gloo"]:
        size = 2
        processes = []
        for rank in range(size):
            p = mp.Process(target=init_process, args=(rank, size, run, backend))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

On my system I get the following output:

nccl: warmup: 1.2995800971984863 sec benchmark time: 0.00042319297790527344 sec
gloo: warmup: 0.0008029937744140625 sec benchmark time: 0.0012202262878417969 sec
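To put these numbers in per-message terms: each benchmark times 10 sends of a 100×100 `float32` tensor (40,000 bytes, since `torch.randn` defaults to `float32`), so dividing by 10 gives a rough per-send latency. The two figures are not measured identically (the Gloo path also includes the GPU-to-CPU copy from `tensor.to("cpu")`, and only rank 0 is timed), so treat this as a back-of-the-envelope comparison.

```python
ITERS = 10                    # timed sends per benchmark, as in the example above
TENSOR_BYTES = 100 * 100 * 4  # 100x100 float32 elements, 4 bytes each

# Benchmark times from the output above (seconds for ITERS sends).
reported = {
    "nccl": 0.00042319297790527344,
    "gloo": 0.0012202262878417969,
}


def per_send_us(total_sec, iters=ITERS):
    """Average time per send, in microseconds."""
    return total_sec / iters * 1e6


for backend, total in reported.items():
    latency_us = per_send_us(total)
    # bytes per microsecond equals megabytes per second (1 MB = 1e6 bytes)
    throughput_mb_s = TENSOR_BYTES / latency_us
    print(f"{backend}: {latency_us:.1f} us per send, ~{throughput_mb_s:.0f} MB/s")
```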

Thank you for your replies.
I reproduced the results in my environment and can confirm that send/recv is faster with NCCL than with Gloo, except for the first communication.