DistributedDataParallel slower than DataParallel?

Hello,

I am trying to use DDP to speed up the training of my model. I was originally using DP for the model training, but I’ve read here (Getting Started with Distributed Data Parallel — PyTorch Tutorials 1.11.0+cu102 documentation) that DDP is faster, so I decided to switch to it. Weirdly enough, training was slower using DDP than using DP… I know something is wrong somewhere but I can’t seem to figure it out, so I tested both DP and DDP on a dummy model for comparison and got similar results… Below are the dummy model and the DP and DDP implementations (based on Optional: Data Parallelism — PyTorch Tutorials 1.11.0+cu102 documentation). What am I doing wrong? Any pointer/advice would be greatly appreciated. This is my first time posting here, so if there’s any additional info that I forgot to add please do not hesitate to ask.

Thank you!

# DP 
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sys
import os
import time
import random
import numpy as np
import pandas as pd
from datetime import datetime
#from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn


class RandomDataset(Dataset):
    """Synthetic dataset: `size` Gaussian features plus one {0, 1} label column.

    Each item is a single row of shape (size + 1,); the label lives in the
    last column so callers can split it off with `row[:-1]` / `row[-1]`.
    """

    def __init__(self, size, length):
        self.len = length
        features = torch.randn(length, size)
        labels = np.random.choice([0, 1], size=(length, 1), p=[1. / 3, 2. / 3])
        # Append the label as the final column of every row.
        self.data = torch.cat((features, torch.from_numpy(labels)), dim=1)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Model(nn.Module):
    """Minimal model: a single fully connected layer.

    Prints the per-call input/output sizes so DataParallel's scattering of
    the batch across GPUs is visible in the logs.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        result = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", result.size())
        return result

# record time:
'''
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
'''
start_time = time.time()
# Parameters and DataLoaders
input_size = 5
output_size = 2

batch_size = 30
data_size = 100

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)
model.to(device)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


for epoch in range(10):
    for data in rand_loader:
        inpt = data[:, :-1].to(device)
        real = data[:, -1].to(device)
        # .long() keeps the tensor on `device` and works on CPU too;
        # torch.cuda.LongTensor crashes when CUDA is unavailable
        # (the fallback branch of `device` above).
        real = real.long()
        # Without zero_grad() the gradients of every previous step
        # accumulate into this one.
        optimizer.zero_grad()
        pred = model(inpt)
        output = loss(pred, real)
        output.backward()
        optimizer.step()
        print("Outside: input size", inpt.size(),
              "output_size", output.size())
#end.record()
#print("time elapsed is: ", start.elapsed_time(end))
print("--- %s seconds ---" % (time.time() - start_time))
# 10 epochs: 4085.677978515625 ms = ~ 5secs
# DDP Example
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sys
import os
import time
import random
import numpy as np
#from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
from numpy.random import randint
#from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
torch.autograd.set_detect_anomaly(True)
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


class RandomDataset(Dataset):
    """Synthetic dataset: `size` Gaussian features plus one {0, 1} label column.

    Each item is one row of shape (size + 1,); the label is the last column,
    so the training loop splits it off with `data[:, :-1]` / `data[:, -1]`.
    """

    def __init__(self, size, length):
        self.len = length
        features = torch.randn(length, size)
        labels = np.random.choice([0, 1], size=(length, 1), p=[1. / 3, 2. / 3])
        # Append the label as the final column of every row.
        self.data = torch.cat((features, torch.from_numpy(labels)), dim=1)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Model(nn.Module):
    """Minimal model: one fully connected layer.

    The size print makes the per-rank input shape visible when the model
    runs under DDP.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        result = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", result.size())
        return result

def setup(rank, world_size):
    """Initialise the default NCCL process group for this rank.

    Uses a fixed localhost rendezvous; every spawned process must call this
    before wrapping its model in DDP.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def main(rank, world_size):
    """Per-process DDP training entry point, spawned once per rank.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of spawned processes.
    """
    input_size = 5
    output_size = 2
    batch_size = 30
    data_size = 100
    pin_memory = True
    num_workers = 4
    manualSeed = 2
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    # setups
    setup(rank, world_size)
    df = RandomDataset(input_size, data_size)
    # DistributedSampler hands each rank a disjoint shard; the DataLoader's
    # own shuffle must stay False when a sampler is supplied.
    sampler = DistributedSampler(df, num_replicas=world_size, rank=rank)
    rand_loader = DataLoader(dataset=df, batch_size=batch_size, pin_memory=pin_memory,
                             num_workers=num_workers, drop_last=False, shuffle=False,
                             sampler=sampler)
    model = Model(input_size, output_size).to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)
    print("working on rank", rank)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(10):
        # Without set_epoch() every epoch reuses the same shuffled order.
        sampler.set_epoch(epoch)
        for data in rand_loader:
            inpt = data[:, :-1].to(rank)
            real = data[:, -1].to(rank)
            # .long() keeps the targets on cuda:rank; the previous
            # .type(torch.cuda.LongTensor) silently moved them to the
            # default device (cuda:0) on every rank.
            real = real.long()
            # Without zero_grad() gradients accumulate across steps.
            optimizer.zero_grad()
            pred = model(inpt)
            output = loss(pred, real)
            output.backward()
            optimizer.step()
            #print("Outside: input size", inpt.size(),
            #      "output_size", output.size())

    # Tear down the process group so the spawned workers exit cleanly.
    dist.destroy_process_group()



if __name__ == '__main__':

    # suppose we have 2 gpus
    world_size = 2
    print("Using Torch version: ", torch.__version__)
    print("Model training is about to start...")
    '''
    # record time:
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    '''

    start_time = time.time()

    # One training process per GPU; each invocation runs main(rank, world_size).
    mp.spawn(main, args=(world_size,), nprocs=world_size)

    print("--- %s seconds ---" % (time.time() - start_time))
    #10 epochs: 25seconds

using:
pytorch 1.11.0 py3.9_cuda11.3_cudnn8.2.0_0 pytorch
pytorch-mutex 1.0 cuda pytorch

TL;DR: With a few minor changes, I see parity for the training loop time on both DP and DDP for your small example on 2 GPUs. Moreover, with some more computation and a larger world size, DDP outperforms DP.

Changes:

  • [DDP] Set num_workers = 0. DDP prefers single-process data loading.
    • This leads to a significant speedup.
  • [DDP] Comment out the CUDA_LAUNCH_BLOCKING=1, TORCH_DISTRIBUTED_DEBUG=DETAIL, and torch.autograd.set_detect_anomaly(True).
  • [DDP] Change the per-rank batch size to 30 // world_size.
    • DP partitions the input and distributes it to each rank, while DDP expects each rank to handle its own batch.
    • This change makes the global effective batch size equal for DP and DDP (with each worker processing an input of shape [15, 5] for a world size of 2).
  • [DP][DDP] Add a torch.cuda.synchronize() before measuring the elapsed time
    • This ensures that the final optimizer step’s kernels finish before measuring the elapsed time.

Notes:

  • mp.spawn() has some initialization overhead that contributes to DDP’s slowness.
    • In practice, this is amortized over the long-running training job.
  • The setup for DDP takes longer than DP: 2.8 seconds versus 1.3 seconds, respectively, on my machine.
    • DDP synchronizes the model across ranks to ensure that they start from the same model parameter values. This requires one or more broadcast() calls, which are time consuming since they synchronize across ranks.
  • Only considering the training loop, the timing becomes equal: both around 2.7 seconds with some variance.
  • DP suffers from Python GIL contention, so it scales poorly with larger world sizes. DDP has one process per worker, so it avoids the Python GIL contention.
  • If I increase the arithmetic intensity by increasing the data_size to 10000 instead of 100 and if I use 4 GPUs, I see DP takes 19.4 seconds, while DDP takes 13 seconds (with some variance), where both measurements are in total including the setup time.

Code:

# DP 
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sys
import os
import time
import random
import numpy as np
import pandas as pd
from datetime import datetime
#from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn


class RandomDataset(Dataset):
    """Synthetic dataset: `size` Gaussian features plus one {0, 1} label column.

    Each item is one row of shape (size + 1,) with the label in the last
    column, so the training loop splits features and target by slicing.
    """

    def __init__(self, size, length):
        self.len = length
        features = torch.randn(length, size)
        labels = np.random.choice([0, 1], size=(length, 1), p=[1. / 3, 2. / 3])
        # Append the label as the final column of every row.
        self.data = torch.cat((features, torch.from_numpy(labels)), dim=1)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Model(nn.Module):
    """A single-linear-layer model; debug size prints are disabled."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        # Size prints kept out of the hot path; re-enable for debugging:
        # print("\tIn Model: input size", input.size(),
        #       "output size", output.size())
        return self.fc(input)

# record time:
'''
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
'''
start_time = time.time()
# Parameters and DataLoaders
input_size = 5
output_size = 2

batch_size = 30
data_size = 10000

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                         batch_size=batch_size, shuffle=True)
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
    model = nn.DataParallel(model)
model.to(device)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

print(f"time for setup: {time.time() - start_time:.3f}")


for epoch in range(10):
    for data in rand_loader:
        inpt = data[:, :-1].to(device)
        real = data[:, -1].to(device)
        # .long() keeps the tensor on `device` and works on CPU too;
        # torch.cuda.LongTensor crashes when CUDA is unavailable
        # (the fallback branch of `device` above).
        real = real.long()
        # Without zero_grad() the gradients of every previous step
        # accumulate into this one.
        optimizer.zero_grad()
        pred = model(inpt)
        output = loss(pred, real)
        output.backward()
        optimizer.step()
        # print("Outside: input size", inpt.size(),
        #       "output_size", output.size())
#end.record()
#print("time elapsed is: ", start.elapsed_time(end))
torch.cuda.synchronize()
print("--- %s seconds ---" % (time.time() - start_time))
# 10 epochs: 4085.677978515625 ms = ~ 5secs
# DDP Example
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import sys
import os
import time
import random
import numpy as np
#from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
import torch.multiprocessing as mp
from numpy.random import randint
#from torch.utils.tensorboard import SummaryWriter
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
# torch.autograd.set_detect_anomaly(True)
# os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


class RandomDataset(Dataset):
    """Synthetic dataset: `size` Gaussian features plus one {0, 1} label column.

    Each item is one row of shape (size + 1,) with the label in the last
    column; the DDP training loop splits it off by slicing.
    """

    def __init__(self, size, length):
        self.len = length
        features = torch.randn(length, size)
        labels = np.random.choice([0, 1], size=(length, 1), p=[1. / 3, 2. / 3])
        # Append the label as the final column of every row.
        self.data = torch.cat((features, torch.from_numpy(labels)), dim=1)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class Model(nn.Module):
    """A single-linear-layer model; debug size prints are disabled."""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        # Size prints kept out of the hot path; re-enable for debugging:
        # print("\tIn Model: input size", input.size(),
        #       "output size", output.size())
        return self.fc(input)

def setup(rank, world_size):
    """Initialise the default NCCL process group for this rank.

    Uses a fixed localhost rendezvous; every spawned process must call this
    before wrapping its model in DDP.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)


def main(rank, world_size):
    """Per-process DDP training entry point, spawned once per rank.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of spawned processes.
    """
    input_size = 5
    output_size = 2
    # Per-rank batch; keeps the global effective batch equal to DP's 30.
    batch_size = 30 // world_size
    data_size = 10000
    pin_memory = True
    num_workers = 0
    manualSeed = 2
    random.seed(manualSeed)
    torch.manual_seed(manualSeed)

    start_time = time.time()

    # setups
    setup(rank, world_size)
    df = RandomDataset(input_size, data_size)
    # DistributedSampler hands each rank a disjoint shard; the DataLoader's
    # own shuffle must stay False when a sampler is supplied.
    sampler = DistributedSampler(df, num_replicas=world_size, rank=rank)
    rand_loader = DataLoader(dataset=df, batch_size=batch_size, pin_memory=pin_memory,
                             num_workers=num_workers, drop_last=False, shuffle=False,
                             sampler=sampler)
    model = Model(input_size, output_size).to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)
    print("working on rank", rank)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print(f"[Rank {rank}] time for setup: {time.time() - start_time:.3f} s")

    for epoch in range(10):
        # Without set_epoch() every epoch reuses the same shuffled order.
        sampler.set_epoch(epoch)
        for data in rand_loader:
            inpt = data[:, :-1].to(rank)
            real = data[:, -1].to(rank)
            # .long() keeps the targets on cuda:rank; the previous
            # .type(torch.cuda.LongTensor) silently moved them to the
            # default device (cuda:0) on every rank.
            real = real.long()
            # Without zero_grad() gradients accumulate across steps.
            optimizer.zero_grad()
            pred = model(inpt)
            output = loss(pred, real)
            output.backward()
            optimizer.step()
            #print("Outside: input size", inpt.size(),
            #      "output_size", output.size())
    torch.cuda.synchronize()

    # Tear down the process group so the spawned workers exit cleanly.
    dist.destroy_process_group()


if __name__ == '__main__':
    # Use every visible GPU: one training process per device.
    world_size = torch.cuda.device_count()
    print("Using Torch version: ", torch.__version__)
    print(f"Model training is about to start with world size {world_size}...")
    '''
    # record time:
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    '''

    start_time = time.time()

    # Each spawned process runs main(rank, world_size).
    mp.spawn(main, args=(world_size,), nprocs=world_size)

    print("--- %s seconds ---" % (time.time() - start_time))
    #10 epochs: 25seconds

Thank you! this was really helpful!