Hi there –
TL;DR: I’m running into performance issues where inter-process communication cost grows linearly with the number of processes I spawn, and I wanted to know what options are available to improve communication performance.
I’m using PyTorch for a distributed learning application that requires me to create a number of process groups roughly linear in the number of processes I’ve spawned, and I have to perform a fair amount of communication within those groups during the forward and backward passes through my model. Unfortunately, the cost of communication also seems to grow linearly with the number of processes I use for training, which negates any performance benefit I’d otherwise get from scaling training out across multiple CPUs / GPUs.
Here’s an MWE showing how the communication costs scale for me:
import datetime
import os
import sys

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def target(process_group, n_iters=1024):
    for _ in range(n_iters):
        x = torch.randn(10, 10)
        dist.all_reduce(x, group=process_group)


def _launch(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if world_size % 2 != 0:
        raise ValueError("The world size must be a multiple of 2")

    # Initialize a process group for every pair of consecutive ranks.
    # Note that new_group() must be called by every rank, even ranks
    # that are not members of the group being created.
    for i in range(0, world_size, 2):
        new_group = dist.new_group([i, i + 1])
        if rank == i or rank == i + 1:
            process_group = new_group

    start = datetime.datetime.now()
    target(process_group)
    end = datetime.datetime.now()
    if rank == 0:
        print(f"Time to run target() = {end - start}")


if __name__ == "__main__":
    if len(sys.argv) != 2:
        raise RuntimeError("Usage: ./test.py <n_process_groups>")
    n_process_groups = int(sys.argv[1])
    world_size = 2 * n_process_groups
    mp.spawn(
        _launch,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )
Here are my results when I run the script:
$ for p in {1..8}; do echo -n "p = $p; " && python3 test.py $p; done
p = 1; Time to run target() = 0:00:00.386286
p = 2; Time to run target() = 0:00:00.478979
p = 3; Time to run target() = 0:00:00.728055
p = 4; Time to run target() = 0:00:00.920900
p = 5; Time to run target() = 0:00:01.150142
p = 6; Time to run target() = 0:00:01.505277
p = 7; Time to run target() = 0:00:01.937280
p = 8; Time to run target() = 0:00:02.068987
I’m not surprised that communication performance degrades as the number of processes increases, but I don’t understand why the cost is growing as quickly as it is. Going from one process group to eight, the runtime of target() grows by roughly 5x, even though each pair of ranks only ever communicates within its own two-process group.
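One caveat about my measurement: I’m not sure whether the timed region picks up any one-time setup cost for the group on its first collective. I’ve considered synchronizing the pair’s group before starting the clock, along these lines (an untested tweak to _launch() above):

    # Hypothetical tweak: run one barrier on the pair's group before timing,
    # so that any one-time connection setup isn't counted in the measurement.
    dist.barrier(group=process_group)
    start = datetime.datetime.now()
    target(process_group)
    end = datetime.datetime.now()

but I haven’t confirmed whether that changes the trend.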
So my question is – are the growing communication costs I’m seeing here just something that’s expected out of the gloo backend, or other elements of the PyTorch distributed communication internals? Or is there a way I could improve the script I have written above so that the target() function runs faster?
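For what it’s worth, here’s one variation on target() I’ve been considering but haven’t benchmarked: issuing the all_reduce calls with async_op=True and waiting on the returned work handles in a batch, in the hope of overlapping the small transfers instead of running them back-to-back.

    def target_async(process_group, n_iters=1024):
        # Hypothetical variant: launch every all_reduce asynchronously and
        # wait on the work handles at the end, so the small 10x10 reductions
        # can overlap rather than serialize.
        tensors = []  # keep tensors alive until their reductions complete
        handles = []
        for _ in range(n_iters):
            x = torch.randn(10, 10)
            tensors.append(x)
            handles.append(dist.all_reduce(x, group=process_group, async_op=True))
        for handle in handles:
            handle.wait()

I’ve also wondered whether coalescing the small 10x10 tensors into a single larger tensor and issuing one all_reduce per batch would amortize the per-call overhead, but I don’t know whether either idea addresses the per-group scaling I’m seeing.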
Any help would be greatly appreciated. Thank you!