Distributed Data Parallel slower than Data Parallel

Hi, there.

I have implemented a Cifar10 classifier using Data Parallel in PyTorch, and then changed the program to use Distributed Data Parallel. I was surprised that the program became much slower. Using 8 GPUs (K80) with a batch size of 4096, the Distributed Data Parallel program spends 47 seconds training a Resnet 34 model for one epoch, while the Data Parallel program takes only 32 seconds.

I run the program in a cloud environment with 8 vCPUs and 52 GB of memory, and it does not seem to be a data transfer problem. So I roughly measured the time spent on each task within the DP and DDP processes. The results are shown below.

DP
[screenshot: DP per-iteration timing output]

DDP
[screenshot: DDP per-iteration timing output]

In the above screenshots, the leftmost number is the value of the loss and the other numbers are the execution times of each task in seconds. “out” represents the forward pass and “back” represents the backward pass. As you can see, DDP takes more than twice the computation time of DP for both the forward and backward passes. I do not understand why this happens.

I suppose that this post discusses the same issue, and it seems that the issue has been addressed.

However, it still happens in my program. The torch version used by the program is 1.4.0. Should I update the version to solve the problem? Or should I use Apex Distributed Data Parallel?


Does this mean each DDP process is consuming 4096 samples per iteration and the DP process is consuming 4096 * 8 = 32768 samples?

I suppose that this post discusses the same issue, and it seems that the issue has been addressed.

For the post you mentioned, that issue only applies to BERT models, and it has been addressed in PyTorch v1.6.

BTW, how did you initialize DDP? Could you please share a repro?

Thank you for your reply.

Does this mean each DDP process is consuming 4096 samples per iteration and the DP process is consuming 4096 * 8 = 32768 samples?

No, I’m talking about the global batch size, which means each DDP process consumes 512 samples per iteration.
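
For clarity, here is a minimal sketch (not part of the original posts) of the batch-size arithmetic assumed in this thread, i.e. 8 GPUs on a single node with a global batch size of 4096:

global_batch_size = 4096
n_gpus = 8                                     # GPUs per node
n_nodes = 1
world_size = n_gpus * n_nodes                  # 8 DDP processes in total
per_process_batch = global_batch_size // world_size
print(per_process_batch)                       # 512 samples per DDP process per iteration
print(global_batch_size)                       # 4096 samples per iteration for the single DP process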

BTW, how did you initialize DDP? Could you please share a repro?

OK, I will try to provide a short version that reproduces the performance gap shortly, since the original program is very long because of automation and visualization. Thank you.


I share below two code listings to reproduce the issue. The first one is the program with DP, provided for comparison. The second one is with DDP, which takes longer for the forward and backward passes than DP.

DP

"""  Training Resnet34 for Cifar10 by Data Parallel """

from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import sys
import time
import argparse

from models import *

from sync_batchnorm import convert_model, DataParallelWithCallback

def main() :
    
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--net', default='res34')
    parser.add_argument('--batch_size', default=4096, type=int)
    parser.add_argument('--optimizer', default="Adam")
    parser.add_argument('--epochs', default=2, type=int)
    parser.add_argument('--n_nodes', default=1, type=int)
    parser.add_argument('--nr', default=0, type=int)
    args = parser.parse_args()

    if torch.cuda.is_available() :
        args.n_gpus = torch.cuda.device_count()
        print(args.n_gpus, " GPU(s) available")
        print(torch.cuda.get_device_name(0))
        
    else :
        print("GPU is NOT available.")   
        sys.exit()
        
    print("Total batch size = ", args.batch_size)    
    print("Batch size = ", int(args.batch_size / args.n_gpus), "/ GPU")    
    print("Optimizer = ", args.optimizer)
    
    train(args)

    print()

       
# Training
def train(args):
    
    epochs = args.epochs
    batch_size = args.batch_size    # total batch_size.
    n_gpus = args.n_gpus
    
    worker = 8
      
    if args.net=='res18':
        net = ResNet18()
    elif args.net=='res34':
        net = ResNet34()
    elif args.net=='res50':
        net = ResNet50()
    elif args.net=='res101':
        net = ResNet101()
    
    print("Model = ", net.__class__.__name__)
    print()
    
    d_list = list(range(n_gpus))        
    net = convert_model(net).cuda() # Convert BatchNorm into SyncBatchNorm
    net = DataParallelWithCallback(net, device_ids = d_list) # Data Parallel
      
    cudnn.benchmark = True  
    
    criterion = nn.CrossEntropyLoss()
    
    if args.optimizer == "Adam" :
        optimizer = optim.Adam(net.parameters())
        
    elif args.optimizer == "SGD" :
        optimizer = optim.SGD(net.parameters(), lr = 0.1)
       
    transform_list = [
                  transforms.RandomChoice([
                  transforms.RandomCrop(32, padding=4),
                  transforms.RandomResizedCrop(32, scale=(0.7, 1.0), ratio = (1.0, 1.0)),
                  ]),
                  transforms.RandomHorizontalFlip(),
                  transforms.RandomRotation(degrees = 20), 
                  transforms.ToTensor(),
                  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                  ]
                  
    transform_train = transforms.Compose(transform_list)
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=worker)

    for epoch in range(epochs):
        
        print()
        print("epoch : ",epoch + 1, " / ", epochs)

        net.train()        
        
        """   ------- Training loop  -------- """
   
        for batch_idx, (inputs, targets) in enumerate(trainloader):
          
            inputs, targets = inputs.to('cuda'), targets.to('cuda')    
            
            message = ""
            t0 = time.time() 
            
            optimizer.zero_grad()      
            
            t1 = time.time() 
            message += "  zero grad: {0:.5f}".format(t1 - t0)
            
            outputs = net(inputs)
            
            t2 = time.time() 
            message += "  out: {0:.5f}".format(t2 - t1)
            
            loss = criterion(outputs, targets)
            
            t3 = time.time() 
            message += "  loss: {0:.5f}".format(t3 - t2)
            
            loss.backward()
            
            t4 = time.time() 
            message += "  back: {0:.5f}".format(t4 - t3)
            
            loss_val = optimizer.step(loss.item)  # optimizer.step(closure) returns closure(), i.e. loss.item()

            t5 = time.time() 
            message += "  step: {0:.5f}".format(t5 - t4)
                 
            print("{0:.6f}".format(loss_val) + message)                    


if __name__ == '__main__':
    main()
    

DDP

"""  Training Resnet34 for Cifar10 by Distributed Data Parallel """

from __future__ import print_function

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import sys
import os
import time
import argparse

from models import *

from sync_batchnorm import convert_model, DataParallelWithCallback

def main() :
    
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
    parser.add_argument('--net', default='res34')
    parser.add_argument('--batch_size', default=4096, type=int)
    parser.add_argument('--optimizer', default="Adam")
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--n_nodes', default=1, type=int)
    parser.add_argument('--nr', default=0, type=int)
    args = parser.parse_args()

    if torch.cuda.is_available() :
        args.n_gpus = torch.cuda.device_count()
        print(args.n_gpus, " GPU(s) available")
        print(torch.cuda.get_device_name(0))
        
        # for DDP
        args.world_size = args.n_gpus * args.n_nodes
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '8888' 
    
    else :
        print("GPU is NOT available.")   
        sys.exit()
        
    print("Total batch size = ", args.batch_size)
    
    args.batch_size = int(args.batch_size / args.world_size) # for DDP
    print("Batch size = ", args.batch_size, "/ GPU")
    
    print("Optimizer = ", args.optimizer)
    
    """ Distributed Data Parallel (DDP)"""
    mp.spawn(train, nprocs=args.n_gpus, args=(args,)) 

    print()

       
# Training
def train(gpu, args):
    
    rank = args.nr * args.n_gpus + gpu
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=args.world_size,
        rank=rank
    )

    epochs = args.epochs
    batch_size = args.batch_size    # batch_size is per GPU size.
  
    torch.manual_seed(0)
    
    if args.net=='res18':
        net = ResNet18()
    elif args.net=='res34':
        net = ResNet34()
    elif args.net=='res50':
        net = ResNet50()
    elif args.net=='res101':
        net = ResNet101()
    
    if rank == 0 :         
        print("Model = ", net.__class__.__name__)
        print()

    
    torch.cuda.set_device(gpu)    
    
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.cuda(gpu)
    
    criterion = nn.CrossEntropyLoss().cuda(gpu)

    if args.optimizer == "Adam" :
        optimizer = optim.Adam(net.parameters())
        
    elif args.optimizer == "SGD" :
        optimizer = optim.SGD(net.parameters(), lr = 0.1)
 
    net = nn.parallel.DistributedDataParallel(net, device_ids=[gpu])

    transform_list = [
                  transforms.RandomChoice([
                  transforms.RandomCrop(32, padding=4),
                  transforms.RandomResizedCrop(32, scale=(0.7, 1.0), ratio = (1.0, 1.0)),
                  ]),
                  transforms.RandomHorizontalFlip(),
                  transforms.RandomRotation(degrees = 20), 
                  transforms.ToTensor(),
                  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                  ]
                  
    transform_train = transforms.Compose(transform_list)
     
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
    	trainset,
    	num_replicas = args.world_size,
    	rank = rank
    )

    trainloader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, 
                                              shuffle=False, num_workers=0,
                                              pin_memory = False, sampler=train_sampler)

    for epoch in range(epochs):
        
        if rank == 0 :
            print()
            print("epoch : ",epoch + 1, " / ", epochs)

        net.train()        
        
        """   ------- Training loop  -------- """
   
        for batch_idx, (inputs, targets) in enumerate(trainloader):
                      
            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            
            message = ""
            t0 = time.time() 
            
            optimizer.zero_grad()      
            
            t1 = time.time() 
            message += "  zero grad: {0:.5f}".format(t1 - t0)
            
            outputs = net(inputs)
            
            t2 = time.time() 
            message += "  out: {0:.5f}".format(t2 - t1)
            
            loss = criterion(outputs, targets)
            
            t3 = time.time() 
            message += "  loss: {0:.5f}".format(t3 - t2)
            
            loss.backward()
            
            t4 = time.time() 
            message += "  back: {0:.5f}".format(t4 - t3)
            
            loss_val = optimizer.step(loss.item)  # optimizer.step(closure) returns closure(), i.e. loss.item()

            t5 = time.time() 
            message += "  step: {0:.5f}".format(t5 - t4)
                 
            if rank == 0 :
                print("{0:.6f}".format(loss_val) + message)                    

    dist.destroy_process_group()  # destroy the process group once, after all epochs finish


if __name__ == '__main__':
    main()
    

Please let me know if something is wrong. Thank you.

Hey @TT_YY, at a quick glance, I noticed that you are using time.time() to measure the time consumption. This does not work for CUDA ops, as they return immediately after the op is enqueued in the CUDA stream, before it actually finishes. You will need to create CUDA events and then use the elapsed_time API.

The code snippet in this comment can serve as an example. Search for torch.cuda.Event.
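
For reference, a minimal sketch of both approaches (not the snippet referenced above), assuming a model net and a batch inputs already on the GPU, with torch and time imported:

# Option 1: make time.time() meaningful by synchronizing around the op.
torch.cuda.synchronize()
t0 = time.time()
outputs = net(inputs)
torch.cuda.synchronize()      # wait until the queued CUDA kernels have finished
t1 = time.time()
print("forward: {0:.5f} s".format(t1 - t0))

# Option 2: CUDA events; elapsed_time() returns milliseconds.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
outputs = net(inputs)
end.record()
torch.cuda.synchronize()      # make sure both events have been recorded
print("forward: {0:.3f} ms".format(start.elapsed_time(end)))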

This does not work for CUDA ops, as they return immediately after the op is enqueued in the CUDA stream, before it actually finishes.

OK, I didn’t know the detail that you explained, but I only meant to get a rough estimate, so I thought that was good enough. Nevertheless, the result matches the actual wall-clock time spent for one epoch of training, as you can see if you run it. The DDP program takes 47 sec. while DP takes 32 sec. in my environment.


Hey @TT_YY, I took a closer look at the code and noticed that you converted BatchNorm to SyncBatchNorm for DDP, which might be the source of the slowness. If you look at SyncBatchNorm's implementation (see below), it launches its own communication, which is not handled by DDP. This additional communication leads to a ~10% slowdown in your program when running on 2 GPUs. When I use BatchNorm instead of SyncBatchNorm, DDP is faster than DP. In general, when comparing DDP and DP speed, we need to make sure that they run the same model.

This is how I measure the latency:

# run one iteration to warm up
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
loss_val = optimizer.step(loss.item) 

# measure latency of the second iteration
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
loss_val = optimizer.step(loss.item)
end.record()
torch.cuda.synchronize()

print(f"world size = {args.world_size}, batch size = {batch_size}, latency = {start.elapsed_time(end)}")

I tried to run the DDP script with the following configs on two GPUs:

  1. Run as is

    world size = 2, batch size = 2048, latency = 506.9587707519531
    world size = 2, batch size = 2048, latency = 506.40606689453125
    
  2. Comment out the following line, as SyncBatchNorm has its own way to communicate buffers, which can be slower.

    #net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    
    world size = 2, batch size = 2048, latency = 456.42352294921875
    world size = 2, batch size = 2048, latency = 457.8104248046875
    
  3. Made the following edits and set args.n_gpus = 1, so that the program runs DataParallel.

    #net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    ...
    #net = nn.parallel.DistributedDataParallel(net, device_ids=[gpu])
    net = nn.parallel.DataParallel(net)
    
    world size = 1, batch size = 4096, latency = 496.3483581542969
    

Thank you for your detailed analysis.

In general, when comparing DDP and DP speed, we need to make sure that they run the same model.

I converted BatchNorm into SyncBatchNorm in DP too, as you can see from convert_model() in the DP code listing above.

As you pointed out, removing convert_model() from the DP program significantly improves the performance (2500 ms/iteration → 1500 ms/iteration). However, I could not see such a difference in latency for the DDP program. The improvement observed in my experiment is less than 4%, with a batch size of 4096 and 8 GPUs. I used the same time measurement method as you did.

I can tolerate the 4% difference if I can make my DDP program faster than DP.

DP with convert_model() -------------------------- DP without convert_model()
[screenshot: per-iteration timings]

DDP with convert_sync_batchnorm() --------- DDP without convert_sync_batchnorm()
[screenshot: per-iteration timings]

I use convert_model(), which converts BatchNorm into SyncBatchNorm for DP; it is different from the built-in torch converter used for DDP, since torch.nn.SyncBatchNorm.convert_sync_batchnorm() supports only DDP.
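
To make the comparison concrete, here is a condensed sketch of the three variants (taken from the listings above, not new functionality); the last one keeps plain BatchNorm in both programs so DP and DDP run the same model:

# DP with synchronized batch norm (third-party sync_batchnorm package):
net = convert_model(net).cuda()                    # BatchNorm -> SyncBatchNorm
net = DataParallelWithCallback(net, device_ids=d_list)

# DDP with synchronized batch norm (built-in converter, DDP only):
net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net).cuda(gpu)
net = nn.parallel.DistributedDataParallel(net, device_ids=[gpu])

# Apples-to-apples comparison: skip both conversions and keep plain BatchNorm,
# i.e. net = net.cuda() wrapped in DataParallel / DistributedDataParallel.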

By the way, I wonder why the latency in your experiment is an order of magnitude lower than mine. Are you using the same model (Resnet34) and Cifar10?

If there is an example program for image classification using DDP, I would be curious to see its latency. I am trying to test my original optimizer with practical settings. So far, the test comparing it with Adam has been successful in terms of the number of steps to reach a target accuracy. Now I have to improve the wall-clock time, and I am trying to find a way to scale the speed with the batch size, like the experiment you have shown. However, DDP is still not working for that purpose in my program.

@mrshenli,

I have been trying to reproduce your results, but, for some reason, my DDP experiment with two GPUs ends up at about 4000 ms, unlike your results of about 500 ms.

I also tried DP with the same settings and the time was 1100 ms. The experiment indicates that DP is faster than DDP for a batch size of 2048 per GPU.

I used a VGG11 model for the experiment, because the K80 cannot fit 2048 Cifar10 samples together with a large model like Resnet34 in its memory. I tried an even smaller model, but DP was still faster than DDP.

Would you please specify the model and data that you used to produce the results in your former post?
Do the GPUs that you used utilize NVLink?

Thank you for your cooperation.

As I don’t have a models package locally, I replaced that with torchvision.models, and then replaced all ResNet* calls with resnet*. I am not sure if that’s also what you use.

Would you please specify the model and data that you used to produce the results in your former post?

I used the default model type in your code, so it should be torchvision.models.resnet34() after my replacement, and I was using the same data loader.
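
Roughly, the replacement looks like this (a sketch; setting num_classes=10 for Cifar10 is an assumption, since the default torchvision head has 1000 classes):

import torchvision

# Stand-in for "from models import *" in the original listings.
def ResNet34():
    return torchvision.models.resnet34(num_classes=10)   # 10 Cifar10 classes

net = ResNet34()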

Do the GPUs that you used utilize NVLink?

Yes, the communication should go through NVLink:

        GPU0    GPU1    CPU Affinity
GPU0     X      NV4     12-23
GPU1    NV4      X      12-23
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.116.00   Driver Version: 418.116.00   CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  Quadro GP100        Off  | 00000000:81:00.0 Off |                    0 |
| 26%   31C    P0    30W / 235W |     10MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  Quadro GP100        Off  | 00000000:82:00.0 Off |                    0 |
| 26%   31C    P0    30W / 235W |     10MiB / 16278MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+

@mrshenli,

Thank you for sharing your GPU spec. That’s what I should have asked in the first place. The GP100 far exceeds the K80 in almost every aspect of performance. In addition, on my platform the K80 does not use NVLink (it may not support it at all). I suppose that this hardware difference is the source of the gap in DDP performance.

On the other hand, I still don’t understand why DP is faster than DDP with the same GPUs in my environment. My guess is that DP, directly performing the reduction with a high-spec CPU, can be faster than DDP performing the all-reduce through gpu0, which communicates over the PCIe bus and on-board memory. But I’m not sure.

Thank you very much.

DP uses replicate, scatter, and gather operations, which are basically cudaMemcpy under the hood. I suspect cudaMemcpy can be faster than allreduce operations used by DDP on some hardware.
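
As a rough sketch (not from the reply above) of the contrast, using public PyTorch APIs and assuming at least two visible GPUs:

import torch
from torch.nn.parallel import replicate, scatter, gather, parallel_apply

devices = [0, 1]                              # assumes two visible GPUs
module = torch.nn.Linear(10, 10).cuda(0)
inputs = torch.randn(8, 10).cuda(0)

# DP pattern: single process, per-iteration device-to-device copies.
replicas = replicate(module, devices)         # copy parameters to every GPU
chunks = scatter(inputs, devices)             # split the batch across GPUs
outputs = parallel_apply(replicas, [(c,) for c in chunks])
result = gather(outputs, target_device=0)     # copy outputs back to GPU 0

# DDP pattern: one process per GPU; after backward, gradient buckets are
# averaged in place with a collective instead of being gathered on one device:
#   torch.distributed.all_reduce(bucket, op=torch.distributed.ReduceOp.SUM)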


Good insight. I have learned a lot. Thank you.

Hi, I am also facing a similar problem. Were you able to find the root cause and solve it?