DistributedDataParallel sub-linear scaling (~85-90% of linear with 2 GPUs)

I am having trouble getting DistributedDataParallel to perform well: 2 GPUs on the same host reach only ~85-90% of linear scaling, and it gets worse as GPUs or hosts are added. From Slack, it seems other users are able to get much closer to 99% of linear with small numbers of nodes/GPUs.

I’m seeing this 85-90% scaling behavior on the (shared) work cluster, and on a 2 GPU system I have at home. I haven’t tested the full cross product, but I’ve seen the same behavior on Ubuntu 14.04 and 18.04; CUDA 9.1, 10.0, and 10.2; stock PyTorch 1.4 DDP and NVIDIA Apex DDP; resnet 50, 152, and some toy models. All used fake data from torchvision with batch sizes that use up the majority of GPU RAM.

The training script is here (with light edits to remove comments, etc.): https://gist.github.com/elistevens/7edacdafdb45747a22da2ef0c6ce1af3
OMP_NUM_THREADS=4 EPOCHS=2 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=2 ~/v/bin/python min_ddp.py etc.

The numbers here are from my 18.04 home system with 2x 1080 Tis. There’s roughly a three-second slowdown for the 2 GPU case, resulting in training going from 22 seconds (1 GPU, 1 epoch) to 25 seconds (2 GPUs, 2 epochs). About a second and a half of that is the {method 'acquire' of '_thread.lock' objects} and the rest seems to be mul_, add_ etc. methods of torch._C._TensorBase objects.
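
To make the "85-90% of linear" figure concrete, here is a minimal sketch of the arithmetic using the two wall-clock totals from the profiles in this post. The function name and signature are made up for illustration; the sample counts follow from EPOCH_SIZE=3840 with 1 epoch on 1 GPU versus 2 epochs on 2 GPUs.

```python
def scaling_efficiency(samples_1gpu, secs_1gpu, samples_ngpu, secs_ngpu, n_gpus):
    """Fraction of ideal linear scaling achieved by the multi-GPU run."""
    throughput_1gpu = samples_1gpu / secs_1gpu   # samples/s on a single GPU
    throughput_ngpu = samples_ngpu / secs_ngpu   # samples/s across all GPUs
    return throughput_ngpu / (n_gpus * throughput_1gpu)


# 1 GPU, 1 epoch of 3840 samples: 22.053 s
# 2 GPUs, 2 epochs of 3840 samples: 25.171 s
efficiency = scaling_efficiency(3840, 22.053, 2 * 3840, 25.171, n_gpus=2)
print(f"{efficiency:.1%} of linear")  # 87.6% of linear
```

Because each process does the same number of iterations in both runs, this reduces to 22.053 / 25.171.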

Is this expected? Am I missing something that would cause performance to be poor like this?

Thanks for any help. More detailed data is below.

1 GPU
         308413 function calls (297131 primitive calls) in 22.053 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       60    5.305    0.088    5.305    0.088 {method 'run_backward' of 'torch._C._EngineBase' objects}
    19320    3.509    0.000    3.509    0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
    19320    3.372    0.000    3.372    0.000 {method 'add_' of 'torch._C._TensorBase' objects}
     9660    2.298    0.000    2.298    0.000 {method 'addcdiv_' of 'torch._C._TensorBase' objects}
     9660    2.124    0.000    2.124    0.000 {method 'sqrt' of 'torch._C._TensorBase' objects}
       60    1.741    0.029   14.598    0.243 /home/elis/v/lib/python3.6/site-packages/torch/optim/adam.py:49(step)
     9660    1.499    0.000    1.499    0.000 {method 'addcmul_' of 'torch._C._TensorBase' objects}
      224    0.671    0.003    0.671    0.003 {method 'acquire' of '_thread.lock' objects}
      120    0.548    0.005    0.548    0.005 {method 'to' of 'torch._C._TensorBase' objects}
     3180    0.141    0.000    0.141    0.000 {built-in method conv2d}
...

2 GPUs
         312342 function calls (301058 primitive calls) in 25.171 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       60    5.355    0.089    5.355    0.089 {method 'run_backward' of 'torch._C._EngineBase' objects}
    19320    4.015    0.000    4.015    0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
    19320    3.668    0.000    3.668    0.000 {method 'add_' of 'torch._C._TensorBase' objects}
     9660    2.407    0.000    2.407    0.000 {method 'sqrt' of 'torch._C._TensorBase' objects}
     9660    2.339    0.000    2.339    0.000 {method 'addcdiv_' of 'torch._C._TensorBase' objects}
      264    2.089    0.008    2.089    0.008 {method 'acquire' of '_thread.lock' objects}
       60    1.800    0.030   15.833    0.264 /home/elis/v/lib/python3.6/site-packages/torch/optim/adam.py:49(step)
     9660    1.566    0.000    1.566    0.000 {method 'addcmul_' of 'torch._C._TensorBase' objects}
      120    0.561    0.005    0.561    0.005 {method 'to' of 'torch._C._TensorBase' objects}
      105    0.275    0.003    0.275    0.003 {built-in method posix.waitpid}
     3180    0.252    0.000    0.252    0.000 {built-in method conv2d}
...

tottime delta in seconds (2 GPUs minus 1 GPU), function
1.418,   {method 'acquire' of '_thread.lock' objects}
0.506,   {method 'mul_' of 'torch._C._TensorBase' objects}
0.296,   {method 'add_' of 'torch._C._TensorBase' objects}
0.283,   {method 'sqrt' of 'torch._C._TensorBase' objects}
0.184,   {built-in method posix.waitpid}
0.111,    {built-in method conv2d}
0.067,   {method 'addcmul_' of 'torch._C._TensorBase' objects}
0.059,   /home/elis/v/lib/python3.6/site-packages/torch/optim/adam.py:49(step)
0.05,    {method 'run_backward' of 'torch._C._EngineBase' objects}
0.049,   {built-in method _posixsubprocess.fork_exec}
0.041,   {method 'addcdiv_' of 'torch._C._TensorBase' objects}
0.037,   {built-in method relu_}
0.023,   {built-in method batch_norm}
0.015,   {built-in method max_pool2d}
0.013,   {method 'to' of 'torch._C._TensorBase' objects}
0.008,   {built-in method torch.distributed._broadcast_coalesced}
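
For reference, a delta table like the one above can be produced from two profile dumps roughly as follows. This is a minimal sketch rather than the exact script used here: the helper names are made up for illustration, and the example numbers are copied from the two tottime tables above.

```python
import pstats


def tottime_by_function(stats):
    """Map each profiled function's display name to its internal time (tottime)."""
    # stats.stats maps (file, line, name) -> (cc, nc, tottime, cumtime, callers)
    return {
        pstats.func_std_string(func): tt
        for func, (cc, nc, tt, ct, callers) in stats.stats.items()
    }


def tottime_deltas(baseline, candidate):
    """Return (name, candidate - baseline) pairs, largest increase first."""
    names = set(baseline) | set(candidate)
    deltas = {n: candidate.get(n, 0.0) - baseline.get(n, 0.0) for n in names}
    return sorted(deltas.items(), key=lambda item: item[1], reverse=True)


# With real dumps you would build the inputs with something like
#   tottime_by_function(pstats.Stats('/tmp/min_profile.out'))
# Here the values are copied by hand from the tables in this post.
one_gpu = {
    "{method 'acquire' of '_thread.lock' objects}": 0.671,
    "{method 'mul_' of 'torch._C._TensorBase' objects}": 3.509,
}
two_gpu = {
    "{method 'acquire' of '_thread.lock' objects}": 2.089,
    "{method 'mul_' of 'torch._C._TensorBase' objects}": 4.015,
}

for name, delta in tottime_deltas(one_gpu, two_gpu):
    print(f"{delta:.3f},   {name}")
```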

Hey @elistevens

Looks like you are already using DDP with one device per process, which is the recommended setup.

Can you try different OMP_NUM_THREADS configurations? Does it speed up or slow down if you set OMP_NUM_THREADS to 1? Sometimes the DataLoader can also cause slowdowns. Does it affect the performance if you get rid of the DataLoader by using synthetically generated input/output (just for testing purposes)?

Yes, I’m using one process per GPU.

Setting OMP_NUM_THREADS to 4 makes little difference compared to 1; leaving it unset causes a very slight performance regression.

I am already using synthetic data via torchvision.datasets.FakeData; each data loader process takes up about 15% CPU with 4 worker processes, and 60% CPU with one. The overall scaling jumps to 92% of linear using 1 worker process, but drops to ~65% with num_workers set to zero (so all of the data loading happens in the main process).

If I get rid of the DataLoader entirely, and just do:

x = torch.rand((batch_size, 3, 224, 224), device='cuda:' + str(gpu_ndx))
y = torch.randint(0, 100, size=(batch_size,), dtype=torch.long, device='cuda:' + str(gpu_ndx))

inside the training loop, then I see a 6% performance improvement in the single-GPU case, and the two-GPU case jumps to 94% of linear (based on the improved single-GPU performance). The primary causes of slowdown are now basic math methods of torch._C._TensorBase:

0.498, {method 'mul_' of 'torch._C._TensorBase' objects}
0.254, {method 'add_' of 'torch._C._TensorBase' objects}
0.176, {method 'sqrt' of 'torch._C._TensorBase' objects}
0.153, {method 'addcdiv_' of 'torch._C._TensorBase' objects}
0.07, {method 'run_backward' of 'torch._C._EngineBase' objects}
0.064, /home/elis/v/lib/python3.6/site-packages/torch/optim/adam.py:49(step)
0.03, {built-in method conv2d}
0.022, {method 'addcmul_' of 'torch._C._TensorBase' objects}
0.012, {built-in method batch_norm}
0.011, {built-in method zeros_like}

The 94% of linear scaling remains the case even if I move the creation of x and y outside the training loop. Switching from Adam to SGD speeds things up marginally, but doesn’t change the ratio.

I removed all the DataLoader overhead and profiling overhead, and I see about 98% scaling:

$ OMP_NUM_THREADS=1 EPOCHS=1 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=1 python /tmp/min_ddp.py
2020-04-21 14:51:31.024092 torch.cuda.set_device(0); torch.distributed.init_process_group('nccl', rank=0, world_size=1)
2020-04-21 14:51:33.167023 Epoch 1, dl: 60
2020-04-21 14:52:08.255089 training loop time: 35.08807826042175 seconds
$ OMP_NUM_THREADS=1 EPOCHS=2 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=2 python /tmp/min_ddp.py
2020-04-21 14:52:15.271820 torch.cuda.set_device(1); torch.distributed.init_process_group('nccl', rank=1, world_size=2)
2020-04-21 14:52:15.278892 torch.cuda.set_device(0); torch.distributed.init_process_group('nccl', rank=0, world_size=2)
2020-04-21 14:52:18.939304 Epoch 1, dl: 30
2020-04-21 14:52:18.939501 Epoch 1, dl: 30
2020-04-21 14:52:37.220102 Epoch 2, dl: 30
2020-04-21 14:52:37.220168 Epoch 2, dl: 30
2020-04-21 14:52:54.701512 training loop time: 35.76222109794617 seconds

Code changes:

import datetime
import math
import os
import time

import torch
import torch.distributed
import torch.multiprocessing
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel

from torch.nn.parallel import DistributedDataParallel
#from apex.parallel import DistributedDataParallel

import torchvision


num_nodes = int(os.environ['NODES'])
num_gpus = int(os.environ['GPUS'])


def main(ddp_wrapper=None, sampler_cls=None, gpu_ndx=0):
    epoch_size = int(os.environ['EPOCH_SIZE'])
    ds = torchvision.datasets.FakeData(
        epoch_size,
        num_classes=100,
        transform=torchvision.transforms.ToTensor(),
    )

    dl = DataLoader(
        ds,
        batch_size=int(os.environ['BATCH_SIZE']),
        num_workers=4,
        pin_memory=True,
        sampler=sampler_cls(ds) if sampler_cls else None,
    )

    model = torchvision.models.resnet50()
    model = model.to('cuda')

    if ddp_wrapper:
        model = ddp_wrapper(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    '''
    import cProfile, pstats, io
    pr = cProfile.Profile()
    pr.enable()
    '''

    batch_size = int(os.environ['BATCH_SIZE'])
    x = torch.rand((batch_size, 3, 224, 224), device='cuda:' + str(gpu_ndx))
    y = torch.randint(0, 100, size=(batch_size,), dtype=torch.long, device='cuda:' + str(gpu_ndx))
    start_ts = time.time()
    for epoch_ndx in range(1, int(os.environ['EPOCHS']) + 1):
        iters = int(epoch_size/batch_size/num_gpus)
        print(datetime.datetime.now(), f"Epoch {epoch_ndx}, dl: {iters}")
        for i in range(iters):
            optimizer.zero_grad()

            #x, y = batch_tup
            x = x.to('cuda')
            y = y.to('cuda')

            y_hat = model(x)
            loss_var = F.cross_entropy(y_hat, y)

            loss_var.backward()
            optimizer.step()
    end_ts = time.time()

    #pr.disable()

    if gpu_ndx == 0:
        '''
        pr.dump_stats('/tmp/min_profile.out')
#        pstats.Stats(pr).sort_stats('cumulative').print_stats()
        pstats.Stats(pr).sort_stats('tot').print_stats()
        '''

        print(datetime.datetime.now(), f"training loop time: {end_ts - start_ts} seconds")
        '''
        print('\n'.join(
            ['min ddp', 'cluster']
            + [os.environ[x] for x in ['NODES', 'GPUS', 'BATCH_SIZE', 'EPOCH_SIZE', 'EPOCHS', 'OMP_NUM_THREADS']]
            + [f'{end_ts - start_ts}']
            + [f"{int(os.environ['EPOCH_SIZE']) * int(os.environ['EPOCHS']) / (end_ts - start_ts) / int(os.environ['GPUS'])}"]
            + [f"{int(os.environ['EPOCH_SIZE']) * int(os.environ['EPOCHS']) / (end_ts - start_ts) / int(os.environ['GPUS']) / 1.737005}"]
        ))
        '''


def ddp_spawn(gpu_ndx):
    node_rank = 0
    rank = num_gpus * node_rank + gpu_ndx
    world_size = num_nodes * num_gpus

    print(datetime.datetime.now(), f"torch.cuda.set_device({gpu_ndx}); torch.distributed.init_process_group('nccl', rank={rank}, world_size={world_size})")

    torch.cuda.set_device(gpu_ndx)
    torch.distributed.init_process_group('nccl', rank=rank, world_size=world_size)

    main(
        ddp_wrapper=lambda m: DistributedDataParallel(m, [gpu_ndx]),
        sampler_cls=torch.utils.data.distributed.DistributedSampler,
        gpu_ndx=gpu_ndx,
    )


if __name__ == '__main__':
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '1234'

    torch.multiprocessing.spawn(ddp_spawn, nprocs=num_gpus, args=())

That’s odd; I see about 95% using the code you posted.

$ OMP_NUM_THREADS=1 NUM_WORKERS=1 EPOCHS=2 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=2 ~/v/bin/python forum_min_ddp.py
2020-04-21 15:18:35.059667 torch.cuda.set_device(1); torch.distributed.init_process_group('nccl', rank=1, world_size=2)
2020-04-21 15:18:35.059667 torch.cuda.set_device(0); torch.distributed.init_process_group('nccl', rank=0, world_size=2)
2020-04-21 15:18:36.918569 Epoch 1, dl: 30
2020-04-21 15:18:36.918801 Epoch 1, dl: 30
2020-04-21 15:18:48.006334 Epoch 2, dl: 30
2020-04-21 15:18:48.007357 Epoch 2, dl: 30
2020-04-21 15:18:59.082558 training loop time: 22.16376233100891 seconds

$ OMP_NUM_THREADS=1 NUM_WORKERS=1 EPOCHS=1 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=1 ~/v/bin/python forum_min_ddp.py
2020-04-21 15:19:09.327511 torch.cuda.set_device(0); torch.distributed.init_process_group('nccl', rank=0, world_size=1)
2020-04-21 15:19:11.018629 Epoch 1, dl: 60
2020-04-21 15:19:31.962914 training loop time: 20.944292783737183 seconds

Could you post your output from nvidia-smi -q?

I should’ve mentioned I was running on master and not 1.4, although I’m not sure if that matters.

nvidia-smi -q seems to include some sensitive information like serial numbers and UUIDs. Was there something specific you’d like to know about my setup? Happy to share that information.

Ahh sorry, I didn’t realize there was potentially sensitive info in there. I was mostly going to visually diff it with what I have here and see if anything jumped out at me.

For example, here is what I see when I’m actually running training: https://gist.github.com/elistevens/dbe5564873a1f55c4ac98594cfd31c63

This is my home system; it’s two 1080 Tis running on PCIe 3.0 8x slots (it’s an older consumer motherboard).

Hi Eli,
TL;DR: I suspect the main reason for the disparity you’re observing is that each epoch, the dataloader processes are shut down and recreated, and that is not free.

I was able to reproduce your issue with your training script, although I had about 95% scaling from the start on my machine (Ubuntu 18.04.2, 28-core Intel Core i9-9940X CPU @ 3.30GHz, 2x Quadro RTX 5000, PyTorch 1.3.1, CUDA 10.0, NCCL 2.4.8) with the parameters

OMP_NUM_THREADS=4 EPOCHS=2 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=2

vs.

OMP_NUM_THREADS=4 EPOCHS=1 EPOCH_SIZE=3840 BATCH_SIZE=64 NODES=1 GPUS=1

Looking at the GPU utilization with nvtop, I noticed a dip in GPU usage between the epochs with GPUS=2. I knew that dataloader processes are destroyed and then recreated from scratch every epoch, so the GPUS=1 EPOCHS=1 version would only do it once, while GPUS=2 EPOCHS=2 would have to do it twice. So I decided to remove this unfairness: I made the script always do just 1 epoch, and instead scale the dataset size like:

ds = torchvision.datasets.FakeData(
    int(os.environ['EPOCH_SIZE']) * int(os.environ['EPOCHS']),

This gave me 99% scaling, and even more when I set EPOCH_SIZE=38400 (10x). The invocation counts also became equal between GPUS=1 and GPUS=2 for the top 20 functions from your pstats output, except for {method 'acquire' of '_thread.lock' objects}, which has 4 extra invocations in the GPUS=2 case. That one comes from the pin_memory thread: if you set pin_memory=False, all those acquires go away, but training, as expected, gets slower in both cases, although the GPUS=1 case suffers more than GPUS=2.
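
To spell out the unfairness, here is a small sketch of the per-process bookkeeping. The helper name is made up for illustration; the iteration count mirrors the `int(epoch_size/batch_size/num_gpus)` expression from the script posted earlier in the thread, and it assumes the dataloader workers are started once per epoch as described above.

```python
def per_process_schedule(dataset_size, epochs, batch_size, gpus):
    """Iterations and dataloader worker startups seen by each training process."""
    iters_per_epoch = dataset_size // batch_size // gpus
    return {
        "iters_per_epoch": iters_per_epoch,
        "total_iters": iters_per_epoch * epochs,
        "worker_startups": epochs,  # workers are recreated at the start of every epoch
    }


# Original comparison: same total iterations, but twice the worker startups.
print(per_process_schedule(3840, epochs=1, batch_size=64, gpus=1))
print(per_process_schedule(3840, epochs=2, batch_size=64, gpus=2))

# Fairer comparison: a single epoch over a dataset of EPOCH_SIZE * EPOCHS samples.
print(per_process_schedule(3840 * 2, epochs=1, batch_size=64, gpus=2))
```

All three configurations run 60 iterations per process, but only the original 2-GPU run pays the worker startup cost twice.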

BTW there is a way to not recreate dataloader processes each epoch and just loop over and over.
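
The post above does not say which approach is meant. One option, as an assumption on my part, is the `persistent_workers=True` argument of DataLoader, which exists in PyTorch 1.7 and later (so not in the 1.4 release discussed in this thread). A minimal sketch with a tiny stand-in dataset; `make_loader` and `run_epochs` are hypothetical helper names:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(num_samples=256, batch_size=64, num_workers=2):
    """Build a DataLoader whose worker processes survive across epochs."""
    # Small random tensors stand in for torchvision.datasets.FakeData here.
    ds = TensorDataset(
        torch.rand(num_samples, 3, 8, 8),
        torch.randint(0, 100, (num_samples,)),
    )
    return DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=num_workers,   # persistent_workers requires num_workers > 0
        persistent_workers=True,   # keep workers alive after an epoch finishes
    )


def run_epochs(dl, epochs):
    """Iterate the same loader repeatedly; returns the number of batches seen."""
    batches = 0
    for _ in range(epochs):
        for x, y in dl:
            batches += 1
    return batches
```

With this setting, iterating `dl` a second time reuses the existing workers instead of forking new ones, so a multi-epoch run should no longer pay the startup cost every epoch.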

Hope this helps and you can replicate!

my pstats for GPUS=1
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      600   50.489    0.084   50.489    0.084 {method 'run_backward' of 'torch._C._EngineBase' objects}
   193200   37.203    0.000   37.203    0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
   193200   32.657    0.000   32.657    0.000 {method 'add_' of 'torch._C._TensorBase' objects}
    96600   20.367    0.000   20.367    0.000 {method 'sqrt' of 'torch._C._TensorBase' objects}
    96600   19.994    0.000   19.994    0.000 {method 'addcdiv_' of 'torch._C._TensorBase' objects}
      600   15.513    0.026  140.393    0.234 /opt/conda/lib/python3.6/site-packages/torch/optim/adam.py:49(step)
    96600   14.561    0.000   14.561    0.000 {method 'addcmul_' of 'torch._C._TensorBase' objects}
     1200    3.360    0.003    3.360    0.003 {method 'to' of 'torch._C._TensorBase' objects}
    31800    1.367    0.000    1.367    0.000 {built-in method conv2d}
      600    1.046    0.002    1.046    0.002 {built-in method torch.distributed._broadcast_coalesced}
    31800    0.919    0.000    0.919    0.000 {built-in method batch_norm}
    31800    0.704    0.000    1.962    0.000 /opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:58(forward)
     1840    0.496    0.000    0.496    0.000 {method 'acquire' of '_thread.lock' objects}
    96439    0.359    0.000    0.359    0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
    29400    0.307    0.000    0.307    0.000 {built-in method relu_}
110400/600    0.254    0.000    5.682    0.009 /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py:531(__call__)
     9600    0.223    0.000    4.182    0.000 /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:95(forward)
        5    0.219    0.044    0.219    0.044 {built-in method _posixsubprocess.fork_exec}
   353400    0.175    0.000    0.175    0.000 /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py:571(__getattr__)
       53    0.116    0.002    0.116    0.002 {built-in method posix.waitpid}
    31800    0.096    0.000    1.057    0.000 /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1643(batch_norm)
      600    0.087    0.000    0.485    0.001 /opt/conda/lib/python3.6/site-packages/torch/optim/optimizer.py:159(zero_grad)
my pstats for GPUS=2
   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      600   50.391    0.084   50.391    0.084 {method 'run_backward' of 'torch._C._EngineBase' objects}
   193200   37.327    0.000   37.327    0.000 {method 'mul_' of 'torch._C._TensorBase' objects}
   193200   32.815    0.000   32.815    0.000 {method 'add_' of 'torch._C._TensorBase' objects}
    96600   20.583    0.000   20.583    0.000 {method 'sqrt' of 'torch._C._TensorBase' objects}
    96600   20.394    0.000   20.394    0.000 {method 'addcdiv_' of 'torch._C._TensorBase' objects}
      600   15.658    0.026  141.634    0.236 /opt/conda/lib/python3.6/site-packages/torch/optim/adam.py:49(step)
    96600   14.755    0.000   14.755    0.000 {method 'addcmul_' of 'torch._C._TensorBase' objects}
     1200    3.549    0.003    3.549    0.003 {method 'to' of 'torch._C._TensorBase' objects}
    31800    1.394    0.000    1.394    0.000 {built-in method conv2d}
      600    1.104    0.002    1.104    0.002 {built-in method torch.distributed._broadcast_coalesced}
    31800    0.924    0.000    0.924    0.000 {built-in method batch_norm}
    31800    0.713    0.000    1.975    0.000 /opt/conda/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py:58(forward)
     1844    0.550    0.000    0.550    0.000 {method 'acquire' of '_thread.lock' objects}
    96439    0.369    0.000    0.369    0.000 {method 'zero_' of 'torch._C._TensorBase' objects}
    29400    0.316    0.000    0.316    0.000 {built-in method relu_}
110400/600    0.261    0.000    5.870    0.010 /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py:531(__call__)
     9600    0.242    0.000    4.253    0.000 /opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:95(forward)
        5    0.215    0.043    0.215    0.043 {built-in method _posixsubprocess.fork_exec}
   353400    0.174    0.000    0.174    0.000 /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py:571(__getattr__)
       53    0.152    0.003    0.152    0.003 {built-in method posix.waitpid}
      600    0.118    0.000    0.118    0.000 {built-in method addmm}
    31800    0.098    0.000    1.063    0.000 /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py:1643(batch_norm)
      600    0.090    0.000    0.499    0.001 /opt/conda/lib/python3.6/site-packages/torch/optim/optimizer.py:159(zero_grad)

That’s a good point; I hadn’t considered the disparity introduced by having a different number of epochs. While fixing up my testing script, I stumbled across what I think is a key culprit: thermal throttling.

My home setup has the two GPUs in adjacent slots, and I think the airflow into the top GPU is being warmed by the backplate of the bottom GPU. If I heat the GPUs up with a job, the top one hits 90C and the pclk reported by nvidia-smi dmon drops, and it drops more with a 2-GPU job.
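
For anyone who wants to check for the same thing from a script, here is a rough sketch that reads temperature and SM clock through nvidia-smi's CSV query mode instead of dmon. The helper names are made up, and the sample string at the bottom is invented purely to show the parsing; it is not a measurement from my system.

```python
import subprocess

QUERY_FIELDS = "index,temperature.gpu,clocks.sm"


def parse_gpu_csv(text):
    """Parse `--format=csv,noheader,nounits` output into one dict per GPU."""
    rows = []
    for line in text.strip().splitlines():
        index, temp_c, sm_mhz = (field.strip() for field in line.split(","))
        rows.append({"index": int(index), "temp_c": int(temp_c), "sm_mhz": int(sm_mhz)})
    return rows


def query_gpus():
    """Run nvidia-smi once and return the parsed readings (needs an NVIDIA driver)."""
    result = subprocess.run(
        ["nvidia-smi", f"--query-gpu={QUERY_FIELDS}", "--format=csv,noheader,nounits"],
        check=True,
        capture_output=True,
        text=True,
    )
    return parse_gpu_csv(result.stdout)


# Invented sample output, for illustration only.
sample = "0, 90, 1480\n1, 78, 1820\n"
for gpu in parse_gpu_csv(sample):
    print(gpu)
```

Calling `query_gpus()` in a loop during a 1-GPU job and again during a 2-GPU job would show whether the SM clock falls further in the second case.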

The hint that clued me in was the 2-GPU times getting worse as I increased the epoch size, rather than better.

While I had tested on work systems, those earlier tests might have suffered from issues with epoch counts, etc. I’m going to rerun those tests with my updated testing script on work systems and see what the results are. I’ll report back when I have them (probably tomorrow).

Thank you to everyone who took the time to read, comment, and/or run my testing script. :)

Short follow-up: with the suggested changes, I was able to get scaling at 98.5% of linear with 2 GPUs on the work cluster. Thanks again!
