DDP - Batch Norm Issue

I am having the issue that everyone else has, where a model that uses BatchNorm has poorer accuracy when using DDP:

According to this, I am supposed to patch BatchNorm somehow:

import torch

def monkey_patch_bn():
    # print(inspect.getsource(torch.nn.functional.batch_norm))
    def batch_norm(input, running_mean, running_var, weight=None, bias=None,
                   training=False, momentum=0.1, eps=1e-5):
        if training:
            # Training needs more than one value per channel to compute statistics.
            size = input.size()
            size_prods = size[0]
            for i in range(len(size) - 2):
                size_prods *= size[i + 2]
            if size_prods == 1:
                raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        # The final argument is cudnn_enabled; passing False bypasses cuDNN for batch norm.
        return torch.batch_norm(
            input, weight, bias, running_mean, running_var,
            training, momentum, eps, False
        )
    torch.nn.functional.batch_norm = batch_norm

But I am not sure how to do it if my code is like this:

def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation, bn_running_avg=False):
    return nn.Sequential(nn.Conv2d(in_planes, out_planes,
                                   kernel_size=kernel_size, stride=stride,
                                   padding=dilation if dilation > 1 else pad,
                                   dilation=dilation, bias=False),


Could you try disabling the cuDNN backend with
torch.backends.cudnn.enabled = False
According to posts such as "Training performance degrades with DistributedDataParallel", this can improve training.
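As a rough sketch of what that looks like in practice: the flag is global, so it applies to layers built inside nn.Sequential without touching the model code. The patch in the question passes False as the cudnn_enabled argument of torch.batch_norm, so the flag is a broader version of the same idea (it affects convolutions too, which may cost some speed). The small conv + BatchNorm2d stack and the input shape below are placeholders, not the model from the question.

```python
import torch
import torch.nn as nn

# Set this once, before building or running the model.
# It disables cuDNN for every op, not only batch norm.
torch.backends.cudnn.enabled = False

# Placeholder conv + BN stack standing in for a convbn-style block.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(8),
)

# Dummy batch of 4 RGB images, 16x16 (hypothetical shape).
out = model(torch.randn(4, 3, 16, 16))
print(out.shape)
```

If accuracy recovers with the flag set, that points at the cuDNN batch norm path rather than at DDP itself.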

Also, have you given SyncBatchNorm (https://pytorch.org/docs/stable/nn.html#syncbatchnorm) a try? It computes batch statistics across all GPUs in use, instead of separately for the batch passed to each device. (Note that, as per the documentation, you'll have to change your code to spawn one process per GPU if you're not training that way already.)
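A minimal sketch of the conversion, assuming the truncated convbn in the question ends with an nn.BatchNorm2d layer (that part of the post is cut off, so this is a guess). nn.SyncBatchNorm.convert_sync_batchnorm walks the module tree and swaps every BatchNorm layer, including ones nested inside nn.Sequential, so the model definition itself does not need to change.

```python
import torch.nn as nn

def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes,
                  kernel_size=kernel_size, stride=stride,
                  padding=dilation if dilation > 1 else pad,
                  dilation=dilation, bias=False),
        nn.BatchNorm2d(out_planes),  # assumed: the original snippet is truncated here
    )

# Hypothetical model built from convbn blocks.
model = nn.Sequential(convbn(3, 16, 3, 1, 1, 1), nn.ReLU())

# Replace every BatchNorm layer with SyncBatchNorm (weights and stats are copied).
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Afterwards, in each per-GPU process (after init_process_group):
#   model = model.to(local_rank)
#   model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```

The conversion itself runs anywhere, but the forward pass of SyncBatchNorm in training mode expects an initialized process group and GPU inputs, which is why the DDP wrapping is only shown as comments here.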

Okay, but this is only useful if I have running_mean_stats enabled, right?