Distributed Data Parallel slower than Data Parallel

Hey @TT_YY, I took a closer look at the code and noticed that you converted BatchNorm to SyncBatchNorm for DDP, which might be the source of the slowness. If you look at SyncBatchNorm's implementation (see below), it launches its own communication, which is not handled by DDP. This additional communication leads to roughly a 10% slowdown in your program when running on 2 GPUs. When I use BatchNorm instead of SyncBatchNorm, DDP is faster than DP. In general, when comparing DDP and DP speed, we need to make sure that both run the same model.
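
Since the fair comparison hinges on both runs using the same model, here is a quick way to confirm which normalization layers the model actually contains (a minimal sketch; `count_norm_layers` is just a helper I made up, assuming `net` is the model from your script):

import torch.nn as nn

def count_norm_layers(net):
    # Count plain BatchNorm vs. SyncBatchNorm modules, so we know which
    # variant the benchmark is actually exercising.
    counts = {"BatchNorm": 0, "SyncBatchNorm": 0}
    for m in net.modules():
        if isinstance(m, nn.SyncBatchNorm):
            counts["SyncBatchNorm"] += 1
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            counts["BatchNorm"] += 1
    return counts

# e.g. print(count_norm_layers(net)) before and after convert_sync_batchnorm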

This is how I measure the latency (in milliseconds, using CUDA events):

# run one iteration to warm up
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()

# measure latency of the second iteration
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
end.record()
torch.cuda.synchronize()

print(f"world size = {args.world_size}, batch size = {batch_size}, latency = {start.elapsed_time(end)}")

I tried to run the DDP script with the following configs on two GPUs:

  1. Run as is

    world size = 2, batch size = 2048, latency = 506.9587707519531
    world size = 2, batch size = 2048, latency = 506.40606689453125
    
  2. Comment out the following line, as SyncBatchNorm has its own way to communicate buffers, which can be slower.

    #net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    
    world size = 2, batch size = 2048, latency = 456.42352294921875
    world size = 2, batch size = 2048, latency = 457.8104248046875
    
  3. Make the following edits and set args.n_gpus = 1, so that the program runs DataParallel.

    #net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    ...
    #net = nn.parallel.DistributedDataParallel(net, device_ids=[gpu])
    net = nn.parallel.DataParallel(net)
    
    world size = 1, batch size = 4096, latency = 496.3483581542969
    