How to use SyncBatchNorm in nn.parallel.DistributedDataParallel with v1.1.0?

Could someone post a short code example showing how to use it?
I have a machine with two GPUs, so I want to use a single process with multiple GPUs.
I tried to use SyncBatchNorm, but sadly it failed like this:

It raises `ValueError: SyncBatchNorm is only supported for DDP with single GPU per process`!
But the DDP docs say that single-process multi-GPU is supported.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.convBlock = nn.Sequential(
            nn.Conv2d(3, 128, 3, 1, 1),
            nn.SyncBatchNorm(128),
            nn.ReLU(),
            nn.Conv2d(128, 512, 3, 1, 1),
            nn.SyncBatchNorm(512),
            nn.ReLU(),
            nn.Conv2d(512, 1, 3, 1, 1),
            nn.SyncBatchNorm(1),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.convBlock(x)
        return x


# A single process (world_size=1, rank=0) that is meant to drive both GPUs.
torch.distributed.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:12345', world_size=1, rank=0)

model = Net().cuda()
# Two device_ids in one process: this is the setup that triggers the ValueError.
model = nn.parallel.DistributedDataParallel(model, device_ids=[0, 1], output_device=0)
optimizer = torch.optim.Adam(model.parameters())
l1_loss = nn.L1Loss()  # mean absolute error
for i in range(1000):
    x = torch.rand(10, 3, 224, 224).cuda()
    y = torch.rand(10, 1, 224, 224).cuda()
    out = model(x)
    optimizer.zero_grad()
    loss = l1_loss(out, y)
    print(i, loss.item())
    loss.backward()
    optimizer.step()
```

This is expected.

While DDP supports using multiple GPUs from a single process, nn.SyncBatchNorm does not and requires you to use a single GPU per process. Also see the docs for torch.nn.SyncBatchNorm:

> Currently SyncBatchNorm only supports DistributedDataParallel with single GPU per process. Use torch.nn.SyncBatchNorm.convert_sync_batchnorm() to convert BatchNorm layer to SyncBatchNorm before wrapping Network with DDP.
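Putting those two requirements together, a minimal sketch for a two-GPU machine might look like the following. It assumes a single machine with GPUs numbered 0..N-1, reuses the layers and the TCP address/port from the question, and `build_model` and `worker` are illustrative names, not part of any API. Each GPU gets its own process via `torch.multiprocessing.spawn`; each process converts the BatchNorm layers with `convert_sync_batchnorm` and then wraps the model in DDP with exactly one entry in `device_ids`. I haven't checked every detail against v1.1.0 specifically.

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn


def build_model():
    # The question's layers, written with plain BatchNorm2d (hypothetical helper).
    return nn.Sequential(
        nn.Conv2d(3, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(),
        nn.Conv2d(128, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(),
        nn.Conv2d(512, 1, 3, 1, 1), nn.BatchNorm2d(1), nn.ReLU(),
    )


def worker(rank, world_size):
    # One process per GPU: the process with rank r only uses cuda:r.
    dist.init_process_group(backend="nccl", init_method="tcp://127.0.0.1:12345",
                            world_size=world_size, rank=rank)
    torch.cuda.set_device(rank)

    # Swap every BatchNorm layer for SyncBatchNorm *before* wrapping with DDP.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(build_model()).cuda(rank)
    # Exactly one entry in device_ids, which is what SyncBatchNorm requires.
    model = nn.parallel.DistributedDataParallel(model, device_ids=[rank], output_device=rank)

    optimizer = torch.optim.Adam(model.parameters())
    l1_loss = nn.L1Loss()
    for i in range(1000):
        # Each process draws its own local batch of 5 samples (10 in total on
        # two GPUs); SyncBatchNorm normalises with mean/variance computed
        # across the batches of all processes.
        x = torch.rand(5, 3, 224, 224).cuda(rank)
        y = torch.rand(5, 1, 224, 224).cuda(rank)
        out = model(x)
        optimizer.zero_grad()
        loss = l1_loss(out, y)
        loss.backward()
        optimizer.step()
        if rank == 0:
            print(i, loss.item())

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()  # 2 on the machine in the question
    # mp.spawn calls worker(rank, world_size) for rank = 0 .. world_size - 1.
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```

Because each process now drives a single GPU, the effective batch is 5 samples per process times 2 processes, which matches the 10 samples the question fed through one process.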


I think this is worth fixing. Running DistributedDataParallel with one process per GPU uses a lot of CPU threads. That is fine on the expensive servers used in industry, but many of us have only a limited number of CPU cores at our disposal.
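One possible mitigation, which is an assumption and not something tested in this thread, is to cap each worker process's PyTorch CPU thread pool so the processes together don't oversubscribe the cores. `threads_per_process` and `limit_cpu_threads` are hypothetical helpers; `torch.set_num_threads` only controls PyTorch's intra-op CPU parallelism, so threads started by NCCL or DataLoader workers are not covered.

```python
import os

import torch


def threads_per_process(world_size, cpu_count=None):
    """Hypothetical helper: split the CPU cores evenly across DDP processes (at least 1 each)."""
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1
    return max(1, cpu_count // world_size)


def limit_cpu_threads(world_size):
    """Hypothetical helper: call at the start of each per-GPU worker process."""
    n = threads_per_process(world_size)
    # Caps PyTorch's intra-op CPU thread pool for this process only; threads
    # created by NCCL, DataLoader workers, etc. are not affected.
    torch.set_num_threads(n)
    return n


if __name__ == "__main__":
    # e.g. 8 cores shared by 2 GPU processes -> 4 threads each
    print(threads_per_process(world_size=2, cpu_count=8))
```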