Improving inter-process communication costs for `torch.distributed`

Hi there –

TL;DR: I’m running into some performance issues where the inter-process communication cost increases linearly in the number of processes that I spawn, and wanted to know what options are available to improve communication performance.


I’m using PyTorch for a distributed learning application that requires me to create a number of process groups that is ~linear in the number of processes that I’ve created, and I have to perform a decent amount of communication within those groups during the forward and backward passes through my model. Unfortunately, it seems like the cost of communication is increasing linearly in the number of processes I’m using for training, which is negating any performance benefits I’d otherwise be getting by scaling out training across multiple CPUs / GPUs.

Here’s an MWE showing how the communication costs scale for me:

import datetime
import os
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def target(process_group, n_iters=1024):
    # Repeatedly all-reduce a small random tensor within this rank's group
    for _ in range(n_iters):
        x = torch.randn(10, 10)
        dist.all_reduce(x, group=process_group)

def _launch(rank, world_size):
    if world_size % 2 != 0:
        raise ValueError("The world size must be a multiple of 2")

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Initialize process groups for every pair of consecutive processes.
    # new_group() must be called by every rank, even ranks not in the group.
    for i in range(0, world_size, 2):
        new_group = dist.new_group([i, i + 1])
        if rank == i or rank == i + 1:
            process_group = new_group

    start = datetime.datetime.now()
    target(process_group)
    end = datetime.datetime.now()

    if rank == 0:
        print(f"Time to run target() = {end - start}")

    dist.destroy_process_group()

if __name__ == "__main__":
    if len(sys.argv) != 2:
        raise RuntimeError("Usage: ./test.py <n_process_groups>")

    n_process_groups = int(sys.argv[1])
    world_size = 2 * n_process_groups

    mp.spawn(
        _launch,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Here are my results when I run the script:

$ for p in {1..8}; do echo -n "p = $p; " && python3 test.py $p; done
p = 1; Time to run target() = 0:00:00.386286
p = 2; Time to run target() = 0:00:00.478979
p = 3; Time to run target() = 0:00:00.728055
p = 4; Time to run target() = 0:00:00.920900
p = 5; Time to run target() = 0:00:01.150142
p = 6; Time to run target() = 0:00:01.505277
p = 7; Time to run target() = 0:00:01.937280
p = 8; Time to run target() = 0:00:02.068987
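To put the timings above in per-call terms, here is a small sketch that simply divides each reported total by the 1024 iterations in target(). It assumes, as in the script, that only rank 0's wall-clock time is reported and that the cost of creating the random tensor is negligible:

```python
# Timings reported above: p process groups, world size 2 * p, 1024 all_reduce calls each.
total_seconds = {
    1: 0.386286,
    2: 0.478979,
    3: 0.728055,
    4: 0.920900,
    5: 1.150142,
    6: 1.505277,
    7: 1.937280,
    8: 2.068987,
}
N_ITERS = 1024

# Average wall-clock time per all_reduce call, in milliseconds.
per_call_ms = {p: t / N_ITERS * 1e3 for p, t in total_seconds.items()}
# Slowdown relative to the single-group (2-process) run.
slowdown = {p: t / total_seconds[1] for p, t in total_seconds.items()}

for p in sorted(total_seconds):
    print(f"p = {p}; world size = {2 * p:2d}; "
          f"{per_call_ms[p]:.3f} ms/call; {slowdown[p]:.2f}x vs p = 1")
```

By this arithmetic, one all_reduce takes about 0.38 ms with a single pair of processes and about 2.0 ms with eight pairs, a roughly 5.4x slowdown, even though every all_reduce still involves only two processes.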

I’m not surprised that communication performance decreases as the number of processes increases, but I don’t understand why the cost is growing as quickly as it is.

So my question is – are the growing communication costs I’m seeing here just something that’s expected from the Gloo backend or other parts of the PyTorch distributed communication internals? Or is there a way I could improve the script I have written above so that the target() function runs faster?

Any help would be greatly appreciated. Thank you!

(Incidentally, I’m currently doing all of my training on CPU. Here’s some additional info about the machine I’m currently running this on, if it helps:)

$ lscpu | grep -E "^(CPU\(s\):|Model name:)"
CPU(s):                48
Model name:            Intel(R) Xeon(R) Silver 4214 CPU @ 2.20GHz

Hi! The Gloo backend is much slower than NCCL, so if you can train on GPU, you may want to switch to NCCL.

A few relatively easy options come to mind:

  1. It will be much more efficient to use DDP (DistributedDataParallel) than to implement your own communication with the allreduce primitive.
  2. Try the no_sync context manager from the tutorial, which reduces the synchronization frequency by accumulating gradients locally.
  3. If you can use NCCL, it will be much faster. You can also try DDP communication hooks for gradient compression.
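A minimal sketch of gradient accumulation with DDP's `no_sync()` context manager, assuming a toy `nn.Linear` model, random data, and a single-process Gloo group on CPU so it runs on its own. The `train_with_no_sync` helper and its accumulation interval are hypothetical, for illustration only; in a real job each rank would run it after `init_process_group`, e.g. via `mp.spawn` as in the MWE.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train_with_no_sync(accumulation_steps=4, n_batches=8):
    # Assumption: a toy linear model and random data stand in for the real model.
    model = DDP(nn.Linear(10, 1))  # CPU module, so no device_ids with Gloo
    opt = torch.optim.SGD(model.parameters(), lr=0.01)
    n_updates = 0
    for step in range(n_batches):
        x, y = torch.randn(16, 10), torch.randn(16, 1)
        if (step + 1) % accumulation_steps != 0:
            # Inside no_sync(), backward() only accumulates gradients locally;
            # DDP skips the gradient all-reduce for this step.
            with model.no_sync():
                nn.functional.mse_loss(model(x), y).backward()
        else:
            # Outside no_sync(), backward() all-reduces the accumulated gradients.
            nn.functional.mse_loss(model(x), y).backward()
            opt.step()
            opt.zero_grad()
            n_updates += 1
    return n_updates


if __name__ == "__main__":
    # Single-process demo (world_size=1) with a file-based rendezvous, so no port is needed.
    init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    print(train_with_no_sync())
    dist.destroy_process_group()
```

With `accumulation_steps=4`, gradients are synchronized once every four batches instead of every batch, which is where the communication savings come from.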

Thanks for the reply!

Unfortunately I can’t currently wrap my models with DDP (I’m integrating with another library that doesn’t play very well with DDP), but my plan is to do so as soon as possible.

That said, the application I’m developing for requires inter-process communication in both the forward and backward passes. The total communication is quite a bit higher than just the setup communication / gradient synchronization performed by DDP, so I suspect that wrapping my model with DDP won’t do very much to reduce communication costs.

I’d like to use NCCL, but it doesn’t directly support some of the communication primitives I need. In any case, for now I’m less concerned about raw communication speed and more concerned with how it scales as the number of processes increases.

If I used the NCCL backend, would the communication speed decrease at the same rate as Gloo (so that e.g. doubling the number of processes / process groups causes all_reduce to run at half speed)?

A few other thoughts:

Have you had a chance to set async_op=True in dist.all_reduce, so that you can overlap some computation with the allreduce?
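A minimal sketch of the async_op=True pattern, assuming a single-process Gloo group so it runs on its own (with world_size=1 the all-reduce leaves the tensor unchanged). `overlapped_step` and the matrix product standing in for "some computation" are hypothetical; in the MWE you would pass the pairwise `process_group` as `group`.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def overlapped_step(x, weight, group=None):
    # Launch the all-reduce without blocking; async_op=True returns a work handle.
    work = dist.all_reduce(x, group=group, async_op=True)
    # Computation that does not read or write x can run while the reduction is in flight.
    y = weight @ weight.T
    # Block until the reduction has finished before using x again.
    work.wait()
    return x, y


if __name__ == "__main__":
    # Single-process Gloo group (world_size=1) so the sketch runs standalone;
    # a file-based rendezvous avoids picking a TCP port.
    init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    x, y = overlapped_step(torch.ones(10, 10), torch.randn(10, 10))
    print(x.sum().item(), y.shape)
    dist.destroy_process_group()
```

The overlap only pays off if there is independent work to do between launching the collective and calling `wait()`; calling `wait()` immediately is equivalent to a blocking all_reduce.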

Additionally, you could try tuning some Gloo environment variables. I am not familiar with Gloo, but I have seen that tuning some NCCL parameters can improve communication performance a lot.

> I’m integrating with another library that doesn’t play very well with DDP

I’m also interested in knowing why DDP isn’t feasible in this case.

> If I used the NCCL backend, would the communication speed decrease at the same rate as Gloo (so that e.g. doubling the number of processes / process groups causes all_reduce to run at half speed)?

NCCL will have both a better scaling efficiency and higher raw speed than Gloo.

I have tried this, yes! It gives a pretty good speedup for the test script (and a smaller speedup for my actual code), but not enough to counteract the growth of communication costs. There may be some other communication operations I could make async, though; I’ll definitely look into it.

I’d definitely like to know if there are any environment variables I can tune for Gloo. I think I’ll need to dig in some more and see if anything like that exists. :slightly_smiling_face:

The library I’m using on top of PyTorch wraps a model and mostly tries to imitate PyTorch’s API for training. Unfortunately, the only API it currently exposes for updating parameters after backprop is a single function that basically just loops over the model parameters and applies a single step of SGD on them.
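Without DDP, one common pattern is to all-reduce and average the gradients by hand between `backward()` and the update call. A minimal sketch, assuming the library lets you reach the parameters as an iterable; `average_gradients` and `sgd_step` are hypothetical names, the latter standing in for the library's loop-over-parameters SGD function. The demo uses a single-process Gloo group so it runs on its own.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn


def average_gradients(params, group=None):
    # Sum each gradient across the ranks in `group`, then divide by the group size.
    world_size = dist.get_world_size(group)
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=group)
            p.grad /= world_size


@torch.no_grad()
def sgd_step(params, lr):
    # Hypothetical stand-in for the library's update function:
    # one plain SGD step on every parameter that has a gradient.
    for p in params:
        if p.grad is not None:
            p.add_(p.grad, alpha=-lr)


if __name__ == "__main__":
    # Single-process Gloo group (world_size=1) so the sketch runs standalone.
    init_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    model = nn.Linear(10, 1)
    loss = nn.functional.mse_loss(model(torch.randn(16, 10)), torch.randn(16, 1))
    loss.backward()
    params = list(model.parameters())
    average_gradients(params)  # synchronize gradients across ranks first
    sgd_step(params, lr=0.01)  # then apply the local update
    dist.destroy_process_group()
```

Note that this issues one all_reduce per parameter; DDP lowers that overhead by grouping gradients into buckets, which a hand-rolled version would have to imitate itself.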

I decided after some experimentation that the amount of effort required to either (a) modify the library to be able to wrap DDP or (b) create a modified version of DDP that could wrap the library was nontrivial. I might end up having a flash of inspiration that’ll help me figure out how to do one or the other, but for now it’s not possible to use DDP for my purposes.

Thank you, that’s good to know. :slightly_smiling_face: If I can’t figure out how to speed up Gloo, I’ll try to see if I can modify my code to use NCCL.