Implementing BatchNorm in PyTorch: problem with updating self.running_mean and self.running_var

It depends a bit on which implementation you mean by the “native” implementation.
E.g. since you are using the GPU, cuDNN would be used, which provides fast algorithms for the batchnorm operations. To disable it, you could set torch.backends.cudnn.enabled = False and compare the speed against your custom implementation again to see if you stay inside the desired 5-10% performance-drop window.
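
A rough timing sketch could look like this (illustrative only: the shapes and iteration count are arbitrary, only the forward pass is timed, and the same loop would be repeated with your custom module):

import time
import torch
import torch.nn as nn

torch.backends.cudnn.enabled = False  # force the native (non-cudnn) batchnorm kernels

bn = nn.BatchNorm2d(64).cuda()        # repeat the same loop with your custom module
x = torch.randn(32, 64, 56, 56, device='cuda')

torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(100):
    out = bn(x)
torch.cuda.synchronize()
print(f'native batchnorm: {time.perf_counter() - t0:.4f}s')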

Hi,
I tried a manual implementation of batch normalization. However, the training accuracy collapses when I use it, and I am not sure what the cause of this may be.
I looked at the running mean and variance, and the tensors are all zeros.
Thanks for your help in advance.

I don’t know which manual implementation you are using, but my reference should update the running stats properly.
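
For comparison, a minimal sketch of a manual nn.BatchNorm2d that does update its running stats during training could look like this (illustrative only, assuming track_running_stats=True and a non-None momentum; not necessarily identical to the reference mentioned above):

import torch
import torch.nn as nn

class ManualBatchNorm2d(nn.BatchNorm2d):
    def forward(self, x):
        if self.training:
            # batch statistics per channel
            var, mean = torch.var_mean(x, dim=[0, 2, 3], unbiased=False)
            n = x.numel() / x.size(1)
            with torch.no_grad():
                # running stats use the unbiased variance, matching nn.BatchNorm2d
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var * n / (n - 1)
        else:
            mean, var = self.running_mean, self.running_var
        x = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        if self.affine:
            x = x * self.weight[None, :, None, None] + self.bias[None, :, None, None]
        return x

If your implementation never writes to self.running_mean and self.running_var (or only does so under the wrong condition), these buffers will stay at their initial values, which would explain the all-zero running mean you are seeing.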

Hi,
In my case, I use multiple nn.BatchNorm1d layers to construct several domain-specific BNs, and I design the dataloader to allocate the data from different domains to different GPUs, so that different BNs get data from different GPUs, e.g. ‘gpu1: bn1, gpu2: bn2, …, gpu_n: bn_n’. After training, bn.weight and bn.bias seem to update properly, but only the running_mean (and running_var) of the first BN (on gpu:0) was updated; the mean and var of bn_2, …, bn_n are not updated at all. How can I solve this problem? Any advice? Thanks

I think you could use SyncBatchNorm to synchronize the stats.
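
E.g. something like this (a sketch; the model here is just a stand-in for your per-domain setup, and the actual synchronization only happens once the converted model is trained with DistributedDataParallel):

import torch.nn as nn

# stand-in model containing a BatchNorm1d layer
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU())

# replaces every BatchNorm layer with SyncBatchNorm; the stats are then
# reduced across all processes during DDP training
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)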

Thank you for your quick reply. The loss and data I use are not friendly for DDP training, so I use DP for multi-GPU training. Thus SyncBatchNorm does not seem to work for me. As you say,

Is there any function in DP to synchronize the buffers on all devices?

I think you are right: nn.DataParallel is not compatible with SyncBatchNorm, and I don’t know how you could synchronize the stats, as the models will be copied from the default device in each iteration. Changing the momentum might help.
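
If you want to experiment with that, a small sketch changing the momentum (or switching to the cumulative moving average via momentum=None) on all batchnorm layers could look like this (the value of 0.01 is just an example):

import torch.nn as nn

def set_bn_momentum(model, momentum=0.01):
    # momentum=None switches nn.BatchNorm* to a cumulative moving average
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.momentum = momentum
    return model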

I’m trying to train EfficientNetV2 using your manual implementation, but it doesn’t work.
The output of your batch norm is not exactly the same as PyTorch’s BatchNorm.
I made this small snippet just to show the problem:

import torch
import torch.nn as nn

class MyBatchNorm2d(nn.BatchNorm2d):
    def forward(self, x):
        var, mean = torch.var_mean(x, dim=[0, 2, 3], unbiased=False)
        x = (x - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        x = x * self.weight[None, :, None, None] + self.bias[None, :, None, None]
        return x

dtype = torch.float
# dtype = torch.double
my_bn = MyBatchNorm2d(3, affine=True).to(dtype)
bn = nn.BatchNorm2d(3, affine=True).to(dtype)
for i in range(10):
    scale = torch.randint(1, 10, (1,)).to(dtype)
    bias = torch.randint(-10, 10, (1,)).to(dtype)
    x = torch.randn(10, 3, 100, 100).to(dtype) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)

    if not torch.allclose(out1, out2):
        print(f'Max diff: {(out1 - out2).abs().max():.10f}')

If you run this code using the double data type (uncomment line 12), there are no discrepancies.
I think this is the problem with my training.
So, does PyTorch use cuDNN on the CPU too?
Is there a way to solve this?

The differences for float32 are in the range of tensor(4.7684e-07, grad_fn=<MaxBackward1>) and for float64 in the range of tensor(5.5067e-14, dtype=torch.float64, grad_fn=<MaxBackward1>), which is expected due to the limited floating-point precision of these numerical formats.
You cannot expect bitwise identical results if different algorithms are used.
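
If you want the comparison to pass, you could use tolerances that fit the dtype instead of the default atol=1e-08 of torch.allclose, e.g. (the tolerance values here are just an example):

import torch

out1 = torch.randn(10, 3, 100, 100)
out2 = out1 + 5e-7 * torch.randn_like(out1)  # stand-in for the ~1e-7 float32 mismatch

print(torch.allclose(out1, out2))                        # likely False with the defaults
print(torch.allclose(out1, out2, rtol=1e-5, atol=1e-5))  # True for errors of this magnitude
print((out1 - out2).abs().max())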

OK, so even when using float64 I’m not able to train it.
I’m using your implementation as a drop-in replacement for nn.BatchNorm2d:

import torch
import timm

# MyBatchNorm2d is the subclass defined in the snippet above

def convert_batchnorm(module):
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm.BatchNorm2d):
        module_output = MyBatchNorm2d(
            module.num_features,
            module.eps,
            module.momentum,
            module.affine,
            module.track_running_stats
        )
        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, convert_batchnorm(child))
    del module
    return module_output


model = timm.create_model('tf_efficientnetv2_s')
model = convert_batchnorm(model)

Have you trained a network this way?
Do you think it should be possible?

No, I haven’t trained a model with it, as it’s mostly a debugging implementation which I used for some reference implementations. The script should contain numerical tests, which were all passing. In case you find any issue with a mismatch, please let me know.
The posted errors are expected, and a mismatch of 1e-14 should not cause any divergence (the numerical mismatch between CPU and GPU would of course be larger in float32 and doesn’t cause divergence either).
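
If you want to test a custom implementation for such a mismatch, a simple check could compare both the outputs and the updated running stats against nn.BatchNorm2d over a few iterations (a sketch; manual_bn is just a placeholder for the implementation under test):

import torch
import torch.nn as nn

torch.manual_seed(0)
manual_bn = nn.BatchNorm2d(3)  # replace with the custom implementation under test
ref_bn = nn.BatchNorm2d(3)

for _ in range(5):
    x = torch.randn(8, 3, 16, 16)
    out_manual = manual_bn(x)
    out_ref = ref_bn(x)
    print('output diff:', (out_manual - out_ref).abs().max().item())

print('running_mean diff:', (manual_bn.running_mean - ref_bn.running_mean).abs().max().item())
print('running_var diff: ', (manual_bn.running_var - ref_bn.running_var).abs().max().item())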