Manual vs. PyTorch BatchNorm calculation: why do the results differ?

Hi,
To better understand PyTorch's BatchNorm, I did the following manual calculation, but its output differs from PyTorch's. Could anyone help me find the reason? Thank you very much : )

Code

import torch 
from torch import nn 

# PyTorch calculation
layer = nn.BatchNorm2d(3)
layer.train()
input = torch.ones(2, 3, 3, 4)
output = layer(input)
print(layer.running_mean)


# Manual calculation
x_mean = input[:, 0, :, :].mean()              # batch mean of channel 0; input is all ones, so this is 1.0
momentum = nn.BatchNorm2d(3).momentum          # default: 0.1
running_mean = nn.BatchNorm2d(3).running_mean  # initial value: zeros
running_mean = momentum * running_mean + (1 - momentum) * x_mean  # 0.1*0 + (1-0.1)*1 = 0.9
print(running_mean)

Output (first line is PyTorch's layer.running_mean, second is the manual result)

tensor([0.1000, 0.1000, 0.1000])
tensor([0.9000, 0.9000, 0.9000])
The update rule is the other way around: PyTorch applies momentum to the new observed value, not to the running estimate, so the correct manual update is

running_mean = (1 - momentum) * running_mean + momentum * x_mean

See the docs: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html?highlight=batchnorm#torch.nn.BatchNorm2d

Note (from the docs): This momentum argument is different from one used in optimizer classes and the conventional notion of momentum. Mathematically, the update rule for running statistics here is \hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t, where \hat{x} is the estimated statistic and x_t is the new observed value.
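
For completeness, here is a minimal sketch of the corrected manual calculation on the same toy input (the variable names are just illustrative, not from your original code); with the corrected formula, the manual result matches layer.running_mean:

import torch
from torch import nn

layer = nn.BatchNorm2d(3)
layer.train()
x = torch.ones(2, 3, 3, 4)
layer(x)  # one forward pass in train mode updates the running statistics

# Corrected manual update: momentum weights the new observation
momentum = layer.momentum       # default: 0.1
running_mean = torch.zeros(3)   # initial value of the running_mean buffer
x_mean = x.mean(dim=(0, 2, 3))  # per-channel batch mean -> tensor([1., 1., 1.])
running_mean = (1 - momentum) * running_mean + momentum * x_mean

print(layer.running_mean)  # tensor([0.1000, 0.1000, 0.1000])
print(running_mean)        # tensor([0.1000, 0.1000, 0.1000]) -- they now match

The same rule applies to running_var, except that PyTorch uses the unbiased batch variance for the running estimate.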

Hi Eta_C,
Sincere thanks for your timely reply.
I followed your tip and everything now works as expected.
Thanks again for your help : )