DistributedDataParallel slows down training with big fully connected layers

When I launch the following script with the torch.distributed.launch utility on a machine with 2 GPUs, training is about 10x slower than when I launch it on a single GPU.

The slowdown seems to come from the big fully connected layer at the end of the network (130000x1024), and I suspect this is because the gradients that must be synchronized at each iteration take up a lot of memory.
I profiled the code with NVIDIA Nsight Systems and saw a call to ncclAllReduceRingLLKernel_sum_f32 that takes approximately 500 ms each iteration.
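One quick way to confirm which parameters dominate the per-iteration all-reduce volume is to rank them by gradient size. The helper below is a hypothetical sketch (largest_grad_params is not a PyTorch API); it assumes each trainable parameter's gradient has the same shape and dtype as the parameter, which is what DDP synchronizes.

```python
import torch.nn as nn


def largest_grad_params(model: nn.Module, top_k: int = 3):
    """Hypothetical helper: return the top_k (name, bytes) pairs by gradient size.

    Assumes each trainable parameter's gradient has the same shape and dtype
    as the parameter itself, so its size is numel() * element_size().
    """
    sizes = [
        (name, p.numel() * p.element_size())
        for name, p in model.named_parameters()
        if p.requires_grad
    ]
    # Largest gradients first; ties keep their registration order.
    return sorted(sizes, key=lambda item: item[1], reverse=True)[:top_k]


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))
    for name, nbytes in largest_grad_params(toy):
        print(f"{name}: {nbytes} bytes")
```

Applied to the Discriminator below, the top entry should be classifier.0.weight, at 131072 × 1024 × 4 = 536,870,912 bytes in fp32.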

Is this expected behaviour with this kind of network? Or am I doing something wrong?

```python
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm import tqdm


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        in_channels = 3
        out_channels = 64
        depth = 7

        # Assumed: a LeakyReLU after each conv, as in the classifier head.
        m_features = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=False),
        ]
        for i in range(depth):
            in_channels = out_channels
            if i % 2 == 1:
                stride = 1
                out_channels *= 2
            else:
                stride = 2
            m_features += [
                nn.Conv2d(in_channels, out_channels, 3, padding=1, stride=stride),
                nn.LeakyReLU(negative_slope=0.2, inplace=False),
            ]

        self.features = nn.Sequential(*m_features)

        # Four stride-2 convs turn 256x256 inputs into 16x16 maps with 512
        # channels, so the first Linear has 512 * 16 * 16 = 131072 inputs.
        patch_size = 256 // (2 ** ((depth + 1) // 2))
        m_classifier = [
            nn.Linear(out_channels * patch_size ** 2, 1024),
            nn.LeakyReLU(negative_slope=0.2, inplace=False),
            nn.Linear(1024, 1),
        ]
        self.classifier = nn.Sequential(*m_classifier)

    def forward(self, f0):
        features = self.features(f0)
        output = self.classifier(features.view(features.size(0), -1))

        return output


torch.backends.cudnn.enabled = True  # make sure to use cuDNN for computational performance

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # set by torch.distributed.launch

args = parser.parse_args()


def train(rank, world_size):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    if world_size > 1:
        torch.distributed.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size)

    discriminator = Discriminator().to(device)
    optimizer = torch.optim.Adam(discriminator.parameters())

    # -- Initialize model for distributed training --
    if world_size > 1:
        discriminator = DDP(discriminator, device_ids=[rank])

    frame = torch.rand((1, 3, 256, 256), device=device)
    # Only used to get the label shape; no_grad keeps this pass out of DDP's
    # gradient bookkeeping.
    with torch.no_grad():
        d_01 = discriminator(frame)

    label_01 = torch.zeros_like(d_01)

    for i in tqdm(range(30)):
        optimizer.zero_grad()

        # - Compute loss -
        d_01 = discriminator(frame)
        loss = F.binary_cross_entropy_with_logits(d_01, label_01)

        # Under DDP, backward() also launches the gradient all-reduce.
        loss.backward()
        optimizer.step()


def main():
    world_size = torch.cuda.device_count()

    with torch.autograd.profiler.emit_nvtx():
        train(args.local_rank, world_size)


if __name__ == '__main__':
    main()
```

Does anyone have any insight into this?

Looking at your code, it seems you didn't create the process group in separate processes and ended up using just one process. Or are you running the script on multiple hosts? If not, did you try following the DDP tutorial, launching one process per GPU, and checking whether that improves performance?

Thank you for your answer :slight_smile: I did follow the tutorial. I use the torch.distributed.launch utility, which takes care of creating one process per GPU. I am running it on one machine with two GPUs.

It is true that I put my two models in the same (default) process group, but I also tried using separate process groups for the two models and it did not change anything.

If the cost is dominated by allreduce communication, can you try the no_sync context manager to reduce the sync frequency?
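A minimal sketch of that suggestion, assuming gradient accumulation over several micro-batches is acceptable for this model. Inside DDP's no_sync() context, forward and backward only accumulate gradients locally; the all-reduce runs on the first forward/backward outside it. accumulate_and_step is a hypothetical helper name, and the contextlib.nullcontext fallback only lets the same code run on a plain module. This lowers how often the large gradients are synchronized, not the cost of each synchronization.

```python
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate_and_step(model, optimizer, inputs, targets):
    """Hypothetical helper: one optimizer step over several micro-batches,
    with a single gradient all-reduce when `model` is wrapped in DDP."""
    optimizer.zero_grad()
    n = len(inputs)
    total = 0.0
    for i, (x, y) in enumerate(zip(inputs, targets)):
        is_last = i == n - 1
        if isinstance(model, DDP) and not is_last:
            ctx = model.no_sync()  # accumulate locally, skip the all-reduce
        else:
            ctx = contextlib.nullcontext()  # last micro-batch: DDP syncs here
        # The forward pass must run inside the context too, not just backward.
        with ctx:
            # Divide by n so the accumulated gradient matches the mean loss.
            loss = F.binary_cross_entropy_with_logits(model(x), y) / n
            loss.backward()
        total += loss.item()
    optimizer.step()
    return total


if __name__ == "__main__":
    # Single-process CPU run without DDP, just to exercise the loop.
    model = nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    xs = [torch.rand(2, 4) for _ in range(4)]
    ys = [torch.zeros(2, 1) for _ in range(4)]
    print(accumulate_and_step(model, opt, xs, ys))
```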

Moreover, you could also try registering an FP16 compression communication hook to compress the gradients before the allreduce. It's a one-line code change. You can try even more advanced gradient compression if you are interested.
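Both hooks mentioned here ship with recent PyTorch releases (1.8 and later) under torch.distributed.algorithms.ddp_comm_hooks; availability depends on the installed version. A minimal sketch, assuming the model is already wrapped in DDP and the hook is registered once, before training starts. enable_fp16_compression and enable_powersgd are hypothetical wrapper names, and the PowerSGD settings are illustrative, not tuned.

```python
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP


def enable_fp16_compression(ddp_model: DDP) -> DDP:
    """Hypothetical wrapper: cast each gradient bucket to float16 for the
    all-reduce and back to float32 afterwards (about half the bytes sent)."""
    # state=None makes the hook use the default (global) process group.
    ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
    return ddp_model


def enable_powersgd(ddp_model: DDP, rank: int = 1) -> DDP:
    """Hypothetical wrapper: low-rank (PowerSGD) gradient compression.
    Illustrative settings: plain all-reduce for the first 10 iterations,
    then rank-`rank` compression."""
    state = powerSGD.PowerSGDState(
        process_group=None,
        matrix_approximation_rank=rank,
        start_powerSGD_iter=10,
    )
    ddp_model.register_comm_hook(state, powerSGD.powerSGD_hook)
    return ddp_model
```

In the question's train(), one of these would be called right after the model is wrapped in DDP. Only one communication hook can be registered per DDP model.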

I am really surprised at such a high allreduce cost.


Thank you for your answer, I am sure your suggestion would have helped!

I finally understood the source of the problem. This network contains this layer: Linear(in_features=131072, out_features=1024, bias=True)

This layer alone has about 134 million parameters, so roughly 512 MiB (about 537 MB) of fp32 gradients must be synchronized every iteration. Since my GPUs are not connected with NVLink, I think it makes sense that this synchronization takes about 500 ms.
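The numbers can be checked with a little arithmetic. This sketch assumes fp32 gradients, a ring all-reduce (as the kernel name ncclAllReduceRingLLKernel suggests), and that the ~500 ms is all transfer time. The "bus bandwidth" formula is the ring cost model used by NCCL's benchmarks, where each rank moves 2(n - 1)/n times the buffer size. The function names are hypothetical.

```python
def linear_grad_bytes(in_features, out_features, bias=True, bytes_per_elem=4):
    """Bytes of gradient one Linear layer contributes to each all-reduce."""
    n_params = in_features * out_features + (out_features if bias else 0)
    return n_params * bytes_per_elem


def ring_allreduce_bus_bandwidth(nbytes, seconds, n_ranks):
    """Bandwidth implied by the ring all-reduce cost model: each rank sends
    (and receives) 2 * (n - 1) / n times the buffer size."""
    return nbytes / seconds * 2 * (n_ranks - 1) / n_ranks


fp32_bytes = linear_grad_bytes(131072, 1024)                    # as in the post
fp16_bytes = linear_grad_bytes(131072, 1024, bytes_per_elem=2)  # after FP16 compression
busbw = ring_allreduce_bus_bandwidth(fp32_bytes, 0.5, n_ranks=2)

print(f"fp32 gradients: {fp32_bytes / 2**20:.1f} MiB")
print(f"fp16 gradients: {fp16_bytes / 2**20:.1f} MiB")
print(f"implied bus bandwidth at 500 ms: {busbw / 1e9:.2f} GB/s")
```

This prints 512.0 MiB, 256.0 MiB and 1.07 GB/s. If the interconnect should sustain much more than that, data volume alone may not explain the 500 ms.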

So I ended up redesigning the network to avoid having such a huge fully connected layer.
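The post does not say how the network was redesigned. One common way to avoid a huge fully connected layer is to global-average-pool the 512-channel feature map before the classifier, so the first Linear layer has 512 inputs instead of 131072. This is only a hypothetical sketch: PooledClassifier is an illustrative name, and whether pooling keeps enough spatial information for this discriminator is an untested assumption.

```python
import torch
import torch.nn as nn


class PooledClassifier(nn.Module):
    """Hypothetical classifier head: global average pooling + small MLP."""

    def __init__(self, channels: int = 512, hidden: int = 1024):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # (N, C, H, W) -> (N, C, 1, 1)
        self.mlp = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Linear(hidden, 1),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Flatten (N, C, 1, 1) to (N, C) before the MLP.
        return self.mlp(torch.flatten(self.pool(features), 1))


if __name__ == "__main__":
    head = PooledClassifier()
    n_params = sum(p.numel() for p in head.parameters())
    print(head(torch.rand(2, 512, 16, 16)).shape, n_params)
```

With the default sizes this head has 526,337 parameters (about 2 MiB of fp32 gradients), compared with roughly 134 million for Linear(131072, 1024).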