Distributed Training slower than DataParallel

Hi all,
I have been using DataParallel so far to train on single-node multiple machines. As I have seen on the forum here, DistributedDataParallel is preferred even for a single node with multiple GPUs, so I switched to distributed training.
My network is fairly large with numerous 3D convolutions, so I can only fit a batch size of 1 (a stereo image pair) on a single GPU.
I have noticed that the time taken by the backward pass increases from 0.7 secs to 1.3 secs.
I have set up distributed training as follows.
GPU utilization is also low. Can you kindly suggest what can be done to increase GPU utilization and reduce the backward pass time?
P.S. Data loading does not seem to be the bottleneck, as it currently takes 0.08 secs.

if config.distributed_training.enable:
    logging.info(f"spawning multiprocesses with {config.distributed_training.num_gpus} gpus")
    multiprocessing.spawn(  # type: ignore
        _train_model,
        nprocs=config.distributed_training.num_gpus,
        args=(pretrained_model, config, train_dataset, output_dir),
    )

def _train_model(
    gpu_index: int, pretrained_model: str, config: CfgNode, train_dataset: Dataset, output_dir: Path
) -> None:
    train_sampler = None
    world_size = _get_world_size(config)
    local_rank = gpu_index

    if config.distributed_training.enable:
        local_rank = _setup_distributed_process(gpu_index, world_size, config)

    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset, num_replicas=world_size, rank=local_rank
    )

    model = MyModel(config) 
    torch.cuda.set_device(local_rank)
    _transfer_model_to_device(model, local_rank, gpu_index, config)
    .........

def _setup_distributed_process(gpu_index: int, world_size: int, config: CfgNode) -> int:
    logging.info("Setting Distributed DataParallel ....")
    num_gpus = config.distributed_training.num_gpus
    local_rank = config.distributed_training.ranking_within_nodes * num_gpus + gpu_index
    torch.cuda.set_device(local_rank)
    _init_process(rank=local_rank, world_size=world_size, backend="nccl")
    logging.info(f"Done...")
    return local_rank

def _init_process(rank: int, world_size: int, backend="gloo"):
    """ Initialize the distributed environment. """
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group(  # type:ignore
        backend=backend, init_method="env://", world_size=world_size, rank=rank
    )

def _transfer_model_to_device(model: nn.Module, local_rank: int, gpu_index: int, config: CfgNode) -> None:
    if config.distributed_training.enable:
        model.cuda(local_rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank  # type:ignore
        )
    elif torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        torch.cuda.set_device(gpu_index)
        model = model.cuda(gpu_index)

GPU utilization is shown in the attached screenshot.

Have you tried using the NCCL backend? It should be considerably faster than Gloo.

And have you measured the time spent on the entire iteration? Most of the overhead of DataParallel (replicating the model, scattering inputs, gathering outputs) is incurred during the forward pass.
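
For reference, here is a minimal sketch of timing the forward and backward phases separately with CUDA events; model, batch, labels, and loss_fn are placeholders rather than names from the thread:

import torch

# assumes model, batch, labels, and loss_fn already exist and live on the same GPU
fwd_start = torch.cuda.Event(enable_timing=True)
fwd_end = torch.cuda.Event(enable_timing=True)
bwd_start = torch.cuda.Event(enable_timing=True)
bwd_end = torch.cuda.Event(enable_timing=True)

fwd_start.record()
outputs = model(batch)
fwd_end.record()

bwd_start.record()
loss_fn(outputs, labels).backward()
bwd_end.record()

torch.cuda.synchronize()  # both event pairs must have completed before querying them
print(f"forward  (ms): {fwd_start.elapsed_time(fwd_end):.1f}")
print(f"backward (ms): {bwd_start.elapsed_time(bwd_end):.1f}")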

@mrshenli thank you for your reply.
I am actually using the NCCL backend; see the following line from the code above:

_init_process(rank=local_rank, world_size=world_size, backend="nccl")

Yes, I have measured the time taken over the entire iteration for both Distributed and DataParallel.
The forward pass takes a similar time in both, or is a bit faster with DistributedDataParallel (0.75 secs vs 0.8 secs in DataParallel).
The overall iteration time is 1.75 secs with DataParallel vs 2.4 secs with DistributedDataParallel, with a similar amount of time spent in data loading (~0.09 secs) in both.

P.S. I just noticed a typo in the first line of my post: my scenario is single-node, multiple GPUs (not machines).

Hey @tyb_10, I tried a toy model locally but cannot reproduce this behavior. With 2 GPUs, the code below shows DP is about 9X slower than DDP. Can you try this code in your environment, or can you share a minimal repro of your code that I can try locally?

DP execution time (ms) by CUDA event: 2938.427490234375
DP execution time (s) by Python time: 2.9386751651763916 
DDP rank-1 execution time (ms) by CUDA event 326.289306640625
DDP rank-0 execution time (ms) by CUDA event 326.19061279296875
DDP rank-1 execution time (s) by Python time 0.3264338970184326 
DDP rank-0 execution time (s) by Python time 0.32636237144470215 
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import DataParallel as DP

import os
import time

X = 100
B = 200

def ddp_example(rank, world_size):
    # create default process group (the env:// rendezvous needs a master address and port)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    b = B // world_size
    # create local model
    model = nn.Linear(X, X).to(rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    with torch.cuda.device(rank):
        tik = time.time()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(20):
            # forward pass
            outputs = ddp_model(torch.randn(b, X).to(rank))
            labels = torch.randn(b, X).to(rank)
            # backward pass
            loss_fn(outputs, labels).backward()
            # update parameters
            optimizer.step()
        end.record()
        print(f"DDP rank-{rank} execution time (ms) by CUDA event {start.elapsed_time(end)}")
        torch.cuda.synchronize()
        tok = time.time()
        print(f"DDP rank-{rank} execution time (s) by Python time {tok - tik} ")


def dp_example():
    b = B  # don't need to divide by 2 here as DataParallel will scatter inputs
    model = nn.Linear(X, X).to(0)
    # construct DP model
    dp_model = DP(model, device_ids=[0, 1])
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(dp_model.parameters(), lr=0.001)

    tik = time.time()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(20):
        # forward pass
        outputs = dp_model(torch.randn(b, X).to(0))
        labels = torch.randn(b, X).to(0)
        # backward pass
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()
    end.record()
    print(f"DP execution time (ms) by CUDA event: {start.elapsed_time(end)}")
    torch.cuda.synchronize()
    tok = time.time()
    print(f"DP execution time (s) by Python time: {tok - tik} ")


def main():
    dp_example()

    world_size = 2
    mp.spawn(ddp_example,
        args=(world_size,),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()


@mrshenli thanks a lot for the toy example.
I can reproduce your results and DDP is indeed faster than DP in this case.
I will debug my code shortly and post the outcome here: either the fix that resolved the issue, or a note if the problem still persists.

@mrshenli so I debugged the code and found one bug: I was not returning the wrapped model to my parent function.
After this change, DDP and DP take a similar time (~1.6 seconds per iteration), but DDP is still not faster than DP in my case.

.....
if config.distributed_training.enable:
    logging.info(f"spawning multiprocesses with {config.distributed_training.num_gpus} gpus")
    multiprocessing.spawn(  # type: ignore
        _train_model,
        nprocs=config.distributed_training.num_gpus,
        args=(pretrained_model, config, train_dataset, output_dir),
    )

def _train_model(
    gpu_index: int, pretrained_model: str, config: CfgNode, train_dataset: Dataset, output_dir: Path
) -> nn.Module:  # this was returning None before
    train_sampler = None
    world_size = _get_world_size(config)
    local_rank = gpu_index
    device = None
    if config.distributed_training.enable:
        local_rank = _setup_distributed_process(gpu_index, world_size, config)

        train_sampler = torch.utils.data.DistributedSampler(
            train_dataset, num_replicas=world_size, rank=local_rank
        )
        device = torch.device(local_rank)
    else:
        device = torch.device("cuda")

    model = MyModel(config) 
    model = _transfer_model_to_device(model, local_rank, gpu_index, config)  # now I assign the returned wrapped model
    dataloader = _initialize_data_loader(train_dataset, config, train_sampler)
    _execute_training(config, device, model, train_sampler, dataloader, output_dir)

def _execute_training(
    config: CfgNode,
    device: torch.device,
    model: nn.Module,
    train_sampler: Optional[torch.utils.data.DistributedSampler],
    dataloader: DataLoader,
    output_dir: Path,
) -> None:
    network_config = config.network_config
    loss_combiner = MultiTaskLoss(num_tasks=2).to(device)
    trainable_params_model = list(filter(lambda p: p.requires_grad, model.parameters()))
    trainable_params = trainable_params_model + list(loss_combiner.parameters())
    optimizer = optim.Adam(
        trainable_params, lr=network_config.learning_rate, betas=(0.9, 0.99), weight_decay=0.001
    )
    logging.info(f"logging to tensorboard at {DEFAULT_TENSORBOARD_LOG_LOCATION}")
    with SummaryWriter(DEFAULT_TENSORBOARD_LOG_LOCATION) as summary_writer:  # type: ignore
        for epoch_idx in range(config.network_config.epochs):
            if config.distributed_training.enable and train_sampler is not None:
                train_sampler.set_epoch(epoch_idx)
            logging.info(f"starting epoch {epoch_idx}")
            _adjust_learning_rate(optimizer, epoch_idx, network_config.learning_rate, network_config.lrepochs)
            # this is the entry point to the rest of the code, which is the same for both DP and DDP
            _train_batches(
                epoch_idx, dataloader, model, optimizer, device, config, loss_combiner, summary_writer
            )


def _initialize_data_loader(
    train_dataset: Dataset, config: CfgNode, train_sampler: Optional[torch.utils.data.DistributedSampler]
) -> DataLoader:
    network_config = config.network_config
    dataloader = DataLoader(
        train_dataset,
        network_config.batch_size,
        collate_fn=custom_collate,
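        # note: DataLoader does not allow shuffle=True together with a sampler,
        # so training_data_shuffle must be False when the DistributedSampler is used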
        shuffle=network_config.training_data_shuffle,
        num_workers=network_config.data_loader_num_workers,
        drop_last=network_config.drop_last_training_sample,
        pin_memory=network_config.data_loader_pin_memory,
        sampler=train_sampler,
    )
    return dataloader
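
The corrected _transfer_model_to_device itself is not shown above; presumably it now returns the wrapped model, roughly along these lines (a sketch based on the version in the first post):

def _transfer_model_to_device(model: nn.Module, local_rank: int, gpu_index: int, config: CfgNode) -> nn.Module:
    # sketch of the fixed helper: the wrapped model is returned, so the caller trains
    # the DDP/DP-wrapped module instead of the bare one
    if config.distributed_training.enable:
        model = model.cuda(local_rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )
    elif torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        torch.cuda.set_device(gpu_index)
        model = model.cuda(gpu_index)
    return model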

Does your model use any buffers? You can check that by running list(model.buffers()).

@mrshenli it doesn't look like it, it prints an empty list :neutral_face:

print(list(model.buffers()))
[]

Shouldn’t it be

torch.cuda.synchronize()
print(f"DP execution time (ms) by CUDA event: {start.elapsed_time(end)}")

Because I got RuntimeError: CUDA error: device not ready when executing your code as-is. start.elapsed_time(end) requires the end event to have completed on the GPU, which the synchronize() call guarantees.

Hi, I am also facing a similar problem. Did you solve yours?

Hey,
actually, soon after this I switched to PyTorch Lightning and now use it for all my compute and data setup. This resolved the issue. It could have been due to some problem in my manual DDP setup within PyTorch; I didn't investigate further. I would highly recommend using Lightning and saving yourself the time you would otherwise spend on various parts of your setup.
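
For what it's worth, here is a minimal sketch of what a single-node multi-GPU setup looks like in Lightning; the module is purely illustrative and the exact Trainer flags depend on the Lightning version:

import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F

class LitModel(pl.LightningModule):
    # illustrative module: replace the network and loss with your own
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(100, 100)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.net(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Lightning spawns the processes, inserts the DistributedSampler, and wraps the model in DDP
trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
# trainer.fit(LitModel(), train_dataloaders=train_loader)  # train_loader: your own DataLoader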

OK, thanks. I solved it. I was using my own data loader; after switching to torch.utils.data.DataLoader, the problem vanished. I didn't debug further to understand the underlying issue.
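
For anyone hitting the same issue, a minimal sketch of the stock combination that works with DDP, i.e. torch.utils.data.DataLoader fed by a DistributedSampler (dataset, rank, world_size, and num_epochs are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# each DDP rank iterates over a disjoint shard of the dataset
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
loader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=4, pin_memory=True)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # re-seed the shuffle for each epoch
    for batch in loader:
        ...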