torch.nn.parallel.DistributedDataParallel slower than torch.nn.DataParallel

I recently built a computer with a dual-GPU setup, specifically two 3090s. I wanted to benchmark the performance increase from the recommended torch.nn.parallel.DistributedDataParallel module, but I found an actual decrease in performance that I'm not sure how to account for.

My code creates a dataset of random images and feeds them through a resnet50 model for 5 epochs with random labels, at various batch sizes. To my surprise, the DDP setup is significantly slower than the single-GPU/DataParallel approach. Here is my code for reference. The performance hit shows up on both Windows and Linux (Ubuntu).

import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision.models import resnet50
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from sys import platform
import torch.multiprocessing as mp


def is_windows():
    return platform == "win32"


def spawn_processes(fn, world_size):
    mp.spawn(fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
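    # NCCL is not supported on Windows, so fall back to the Gloo backend there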
    dist.init_process_group("gloo" if is_windows() else "nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class RandomImagesDataloader(Dataset):
    def __init__(self, num_images=500, height=600, width=600, num_channels=3):
        self.num_images = num_images
        self.dataset = torch.randn(num_images, num_channels, height, width)
        self.len = num_images

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return self.len


def train(rank, world_size):
    setup(rank, world_size)
    model = resnet50().to(rank)
    model = DDP(model, device_ids=[rank])
    batch_sizes = [4, 8, 16, 20]
    num_epochs = 5
    # warmup iterations
    for i in range(10):
        sample_input = torch.rand(10, 3, 600, 600).to(rank)
        _ = model(sample_input)
    for batch_size in batch_sizes:
        dl = DataLoader(dataset=RandomImagesDataloader(),
                        batch_size=batch_size, shuffle=True,
                        num_workers=1, drop_last=True)
        optimizer = optim.SGD(params=model.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()
        total_time = 0.0
        for epoch_num in range(num_epochs):
            for batch_num, batch in enumerate(dl):
                start_event, end_event = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                start_event.record()
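                # note: the measured interval also includes creating the random targets and copying the batch to the GPU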
                targets = torch.randint(size=(batch_size,), low=0, high=1000).long().to(rank)
                batch = batch.to(rank)
                output = model(batch)
                loss = loss_fn(output, targets)
                loss.backward()
                optimizer.step()
                end_event.record()
                torch.cuda.synchronize()
                total_time += start_event.elapsed_time(end_event)
        if rank == 0:
            print(f"The estimated training time for {world_size} gpu/s at batch size "
                  f"{batch_size} is {round(total_time/1000.0, 3)} seconds")
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    spawn_processes(train, world_size)

Compared to my code here, which benchmarks both a single-GPU setup and a DataParallel setup (an approach that is explicitly not recommended in the docs and several other sources), the code above is significantly slower at all batch sizes.
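
For context, the comparison script is essentially a single-process version of the same benchmark, with the model optionally wrapped in nn.DataParallel. A condensed sketch of it (not the exact code from the linked repo, but the same structure and hyperparameters) looks roughly like this:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50


class RandomImagesDataset(Dataset):
    def __init__(self, num_images=500, height=600, width=600, num_channels=3):
        self.dataset = torch.randn(num_images, num_channels, height, width)

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return self.dataset.shape[0]


def benchmark(use_data_parallel, batch_sizes=(4, 8, 16, 20), num_epochs=5):
    device = torch.device("cuda:0")
    model = resnet50().to(device)
    if use_data_parallel:
        # DataParallel replicates the model on every visible GPU and splits each batch across them
        model = nn.DataParallel(model)
    # warmup iterations so one-time CUDA overhead is not measured
    for _ in range(10):
        _ = model(torch.rand(10, 3, 600, 600, device=device))
    loss_fn = nn.CrossEntropyLoss()
    for batch_size in batch_sizes:
        dl = DataLoader(RandomImagesDataset(), batch_size=batch_size,
                        shuffle=True, num_workers=1, drop_last=True)
        optimizer = optim.SGD(params=model.parameters(), lr=1e-3)
        total_time = 0.0
        for epoch_num in range(num_epochs):
            for batch in dl:
                start_event = torch.cuda.Event(enable_timing=True)
                end_event = torch.cuda.Event(enable_timing=True)
                start_event.record()
                targets = torch.randint(low=0, high=1000, size=(batch_size,)).to(device)
                batch = batch.to(device)
                loss = loss_fn(model(batch), targets)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                end_event.record()
                torch.cuda.synchronize()
                total_time += start_event.elapsed_time(end_event)
        mode = "DataParallel" if use_data_parallel else "single GPU"
        print(f"The estimated training time for {mode} at batch size "
              f"{batch_size} is {round(total_time / 1000.0, 3)} seconds")


if __name__ == "__main__":
    benchmark(use_data_parallel=False)
    benchmark(use_data_parallel=True)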

I’m using PyTorch 1.11. Any insight into what could be going wrong, or whether this is expected behavior, would be appreciated.

Forcing the CPU to sleep to simulate synchronization isn’t the right approach:

        # dealing with synchronization issues
        time.sleep(1.0)
        t2 = time.time()

and you should either synchronize the code via torch.cuda.synchronize() or check a profile via e.g. Nsight Systems to see how the workload actually executes.
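
For example, a minimal event-based measurement (shown here with just a forward pass; the same pattern applies to a full training step) could look like this:

import torch
import torchvision

model = torchvision.models.resnet50().cuda()
x = torch.randn(16, 3, 224, 224, device="cuda")

# warmup iterations so that one-time CUDA initialization is not measured
for _ in range(10):
    _ = model(x)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
out = model(x)
end.record()
# wait for the recorded events to complete before reading the elapsed time
torch.cuda.synchronize()
print(f"forward pass took {start.elapsed_time(end):.3f} ms")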

I am now using CUDA events with synchronization (I updated the GitHub example as well), and the same issue occurs as before. I will look into Nsight Systems, but is there anything in my code that sticks out as a potential cause of the slowdown?

What’s particularly odd is that GPU utilization is really high on both GPUs in the distributed case.

OK, so I wasn’t using a DistributedSampler, which meant every process iterated over the full dataset, effectively going over the data twice. Here is the fixed code.

import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torchvision.models import resnet50
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from sys import platform
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler


def is_windows():
    return platform == "win32"


def spawn_processes(fn, world_size):
    mp.spawn(fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("gloo" if is_windows() else "nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class RandomImagesDataloader(Dataset):
    def __init__(self, num_images=5000, height=224, width=224, num_channels=3):
        self.num_images = num_images
        self.dataset = torch.randn(num_images, num_channels, height, width)
        self.len = num_images

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return self.len


def train(rank, world_size):
    setup(rank, world_size)
    model = resnet50().to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank)
    batch_sizes = [16, 32, 64]
    num_epochs = 5
    # warmup iterations
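    # (the warmup still uses 600x600 inputs even though the dataset below is 224x224; these passes are not timed)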
    for i in range(10):
        sample_input = torch.rand(10, 3, 600, 600).to(rank)
        _ = model(sample_input)
    for batch_size in batch_sizes:
        dataset = RandomImagesDataloader()
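        # DistributedSampler assigns each rank a disjoint shard of the dataset, so each process only sees 1/world_size of the data per epoch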
        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False, drop_last=True)
        dl = DataLoader(sampler=sampler, dataset=dataset,
                        batch_size=batch_size, shuffle=False,
                        num_workers=0, drop_last=True)
        optimizer = optim.SGD(params=model.parameters(), lr=1e-3)
        loss_fn = nn.CrossEntropyLoss()
        total_time = 0.0
        for epoch_num in range(num_epochs):
            for batch_num, batch in enumerate(dl):
                start_event, end_event = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
                start_event.record()
                targets = torch.randint(size=(batch_size,), low=0, high=1000).long().to(rank)
                batch = batch.to(rank)
                output = model(batch)
                loss = loss_fn(output, targets)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                end_event.record()
                torch.cuda.synchronize()
                total_time += start_event.elapsed_time(end_event)
        if rank == 0:
            print(f"The estimated training time for {world_size} gpu/s at batch size "
                  f"{batch_size} is {round(total_time/1000.0, 3)} seconds")
    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    spawn_processes(train, world_size)

Thanks for the follow-up! Did you see an improvement with DDP vs. DataParallel using the fixed setup?

Yes, I did, once the batch size was sufficiently large.
