SyncBatchNorm not working with autocast and mixed-precision

I’m trying to use torch.nn.SyncBatchNorm.convert_sync_batchnorm in my DDP model. I can currently train with DDP and mixed precision via torch.cuda.amp.autocast without any problem, but it breaks once I add torch.nn.SyncBatchNorm. I am running PyTorch 1.8.1 and Python 3.8 with CUDA 10.2. Here is how I am setting up the model:

    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = net.to(device)
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank], find_unused_parameters=False)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    scaler = GradScaler()

    for epoch in range(starting_epoch, epochs):
        for idx, batch in enumerate(train_loader):
            with autocast():
                pred = net(batch['data'])
                loss = loss_fn(pred, batch['target'])

            # zero the gradients, then run backward and the optimizer step through the scaler
            for param in net.parameters():
                param.grad = None
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

This trains without any problem normally, but after adding torch.nn.SyncBatchNorm I get the following error:

  File "/home/.conda/envs/main_env_2/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 545, in forward
    return sync_batch_norm.apply(
  File "/home/.conda/envs/main_env_2/lib/python3.8/site-packages/torch/nn/modules/_functions.py", line 38, in forward
    mean, invstd = torch.batch_norm_gather_stats_with_counts(
RuntimeError: expected scalar type Half but found Float

I also tried wrapping the torch.nn.SyncBatchNorm call in autocast, but it did not help.
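
Roughly, by “wrapping” I mean running just the norm layer inside its own autocast region in the forward pass, something along these lines (this module is only a sketch of my setup, not my actual model):

    import torch

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 16, 3, 1, 1)
            self.bn = torch.nn.BatchNorm2d(16)  # later converted via convert_sync_batchnorm

        def forward(self, x):
            x = self.conv(x)
            # only the norm layer runs inside its own autocast context
            with torch.cuda.amp.autocast():
                x = self.bn(x)
            return x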

I cannot reproduce the issue using:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
import types
import argparse


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


def main():
    parser = argparse.ArgumentParser(description='fdsa')
    parser.add_argument("--local_rank", default=0, type=int)

    args = parser.parse_args()
    args.gpu = args.local_rank
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

    model = MyModel().to(args.gpu)
    # convert BatchNorm layers to SyncBatchNorm before wrapping the model in DDP
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = DistributedDataParallel(
        model,
        device_ids=[args.gpu],
        output_device=args.local_rank,
        broadcast_buffers=False,
        find_unused_parameters=False
    )
    scaler = torch.cuda.amp.GradScaler()
    criterion = nn.MSELoss()

    for i in range(2):
        model.zero_grad()
        x = torch.randn(2, 3, 16, 16, device=args.gpu)
        target = torch.randn(2, 16, 16, 16, device=args.gpu)
        with torch.cuda.amp.autocast():
            out = model(x)
            loss = criterion(out, target)

        scaler.scale(loss).backward()
        

if __name__ == "__main__":
    main()

Could you post an executable code snippet or adapt mine to reproduce the error?

I adapted your snippet to reproduce the error. The issue with the code above is that with world_size=1 it never reaches the synchronization branch in nn.SyncBatchNorm (lines 528 to 548 of torch/nn/modules/batchnorm.py), which is the code that raises the error. Here is the adapted snippet:

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.distributed as dist
import argparse
from torch.cuda.amp import autocast, GradScaler
import os
import torch.multiprocessing as mp


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, 1, 1)
        self.bn = nn.BatchNorm2d(16, track_running_stats=False)

    def forward(self, x):
        with autocast():
            x = self.conv(x)
            x = self.bn(x)
            return x


def main(gpu):
    parser = argparse.ArgumentParser(description='fdsa')
    parser.add_argument("--local_rank", default=0, type=int)

    args = parser.parse_args()
    # pin each spawned process to its own GPU; gpu is the index passed in by mp.spawn
    torch.cuda.set_device(gpu)

    rank = gpu
    device = f'cuda:{gpu}'

    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=2,
        rank=rank)

    args.world_size = torch.distributed.get_world_size()

    model = MyModel()
    model = model.to(device)
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    model = DistributedDataParallel(
        model,
        device_ids=[rank],
        output_device=rank,
        broadcast_buffers=False,
        find_unused_parameters=False
    )

    scaler = GradScaler()
    criterion = nn.MSELoss()

    for i in range(2):
        model.zero_grad()
        x = torch.randn(2, 3, 16, 16).to(device)
        target = torch.randn(2, 16, 16, 16).to(device)
        with autocast():
            out = model(x)
            loss = criterion(out, target)

        scaler.scale(loss).backward()

    print('completed')

if __name__ == "__main__":
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    mp.spawn(main, nprocs=2)

I don’t think the difference is caused by the world_size but by the batch norm layer, which doesn’t track any stats in your example (track_running_stats=False) and would therefore make SyncBatchNorm useless. If you enable the running stats, your code snippet works on my machine.
However, thanks for the code snippet; at the very least the error message should be improved.
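
Concretely, a minimal sketch of the change, assuming everything else in your repro stays the same:

    import torch.nn as nn
    from torch.cuda.amp import autocast

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, 3, 1, 1)
            # leave track_running_stats at its default (True) so the converted
            # SyncBatchNorm layer has running stats to synchronize across ranks
            self.bn = nn.BatchNorm2d(16)

        def forward(self, x):
            with autocast():
                x = self.conv(x)
                x = self.bn(x)
                return x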