Batchnorm running statistics issue

I'm probably missing something obvious, but I can't figure it out. I tried to partially reproduce the running statistics of a BatchNorm2d layer (while in train mode), but I don't seem to be able to get them right. It's not like it matters much, but at this point I'm just curious what I'm doing wrong. Here is my attempt:

import torch


class BatchNorm2d():

    def __init__(self, epsilon=1e-05, momentum=0.1):  # Using the same defaults as PyTorch
        """
        Expects a tensor of shape (BS:[batch_size], C:[channels], H:[height], W:[width]).
        """
        self.eval_mode = True
        self.epsilon = epsilon
        self.momentum = momentum
        self.batches_processed_while_training = 0
        self.running_tensor_mean = 0.0
        self.running_tensor_var = 1.0

    def __call__(self, tensor):

        if not self.eval_mode:  # TRAINING
            self.C = tensor.shape[1]
            current_tensor_mean = tensor.mean((0, 2, 3))  # Mean over the batch (0), height (2) and width (3) --> shape (C,)
            current_tensor_var = tensor.var((0, 2, 3), unbiased=False)  # Variance (unbiased) over the batch (0), height (2) and width (3) --> shape (C,)

            self.running_tensor_mean = (1 - self.momentum) * self.running_tensor_mean + self.momentum * current_tensor_mean
            self.running_tensor_var = (1 - self.momentum) * self.running_tensor_var + self.momentum * current_tensor_var

            self.batches_processed_while_training += 1
            # The 1-d (per-channel) vectors need to be reshaped to (C, 1, 1) so that broadcasting works as expected.
            return (tensor - current_tensor_mean.reshape(self.C, 1, 1)) / torch.sqrt(current_tensor_var.reshape(self.C, 1, 1) + self.epsilon)

        else:  # EVAL mode, not implemented here
            pass

When fed sequentially with 3 batches (of size 10), each generated by torch.randn((10, 3, 8, 8)), the output is the following (the "Output" rows are [0, 0, 0, 0:5] slices of the [10, 3, 8, 8] outputs):

Batch1:
Output: tensor([-1.7355, 1.2432, -0.9514, 1.0702, 0.5833])
running mean: tensor([ 0.0031, -0.0098, -0.0050])
running var: tensor([1.0116, 0.9778, 1.0031])

Batch2:
Output: tensor([ 0.2073, -0.0832, 0.7110, 0.8565, 0.5534])
running mean: tensor([ 0.0332, 0.0244, -0.0110])
running var: tensor([1.0217, 0.9842, 0.9913])

Batch3:
Output: tensor([-0.4904, -0.3256, 2.0108, -1.0837, 1.0151])
running mean: tensor([0.0276, 0.0215, 0.0227])
running var: tensor([0.9963, 0.9632, 0.9768])

I compare the results with those obtained from torch.nn.BatchNorm2d in train mode:

Batch1:
Output: tensor([-1.7355, 1.2432, -0.9514, 1.0702, 0.5833], grad_fn=<...>)
running mean: tensor([ 0.0031, -0.0098, -0.0050])
running var: tensor([1.0122, 0.9784, 1.0037])

Batch2:
Output: tensor([ 0.2073, -0.0832, 0.7110, 0.8565, 0.5534], grad_fn=<...>)
running mean: tensor([ 0.0332, 0.0244, -0.0110])
running var: tensor([1.0227, 0.9852, 0.9923])

Batch3:
Output: tensor([-0.4904, -0.3256, 2.0108, -1.0837, 1.0151], grad_fn=<...>)
running mean: tensor([0.0276, 0.0215, 0.0227])
running var: tensor([0.9975, 0.9644, 0.9780])
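
For reference, this is roughly the loop I used to produce the numbers above (a sketch; names like my_bn/torch_bn and the seed are incidental, and the exact values depend on the random data):

import torch
import torch.nn as nn

torch.manual_seed(0)           # hypothetical seed; the run above did not fix one

my_bn = BatchNorm2d()          # the class above
my_bn.eval_mode = False        # put it in training mode
torch_bn = nn.BatchNorm2d(3)   # default affine params (weight=1, bias=0), so outputs match numerically
torch_bn.train()

for i in range(3):
    batch = torch.randn((10, 3, 8, 8))
    out_mine = my_bn(batch)
    out_torch = torch_bn(batch)
    print(f"Batch{i + 1}:")
    print("Output:", out_mine[0, 0, 0, :5], out_torch[0, 0, 0, :5])
    print("running mean:", my_bn.running_tensor_mean, torch_bn.running_mean)
    print("running var:", my_bn.running_tensor_var, torch_bn.running_var)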

Why is the running variance wrong? If it's not clear, I could provide a GitHub link, I guess. Thank you in advance.

The running_var update is performed using the unbiased variance (i.e. unbiased=True), while the normalization itself uses the biased one, as you can see in my manual example and in the PyTorch backend implementation.
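
Concretely, the fix is to keep two variances per batch: the biased one for normalizing the current batch and the unbiased one for updating the running statistics. A minimal sketch of one training step (written as a standalone function, not the exact backend code):

import torch

def bn2d_train_step(x, running_mean, running_var, momentum=0.1, eps=1e-5):
    c = x.shape[1]
    mean = x.mean((0, 2, 3))
    var_biased = x.var((0, 2, 3), unbiased=False)   # normalizes the current batch
    var_unbiased = x.var((0, 2, 3), unbiased=True)  # feeds the running statistics
    running_mean = (1 - momentum) * running_mean + momentum * mean
    running_var = (1 - momentum) * running_var + momentum * var_unbiased
    out = (x - mean.reshape(c, 1, 1)) / torch.sqrt(var_biased.reshape(c, 1, 1) + eps)
    return out, running_mean, running_var

With that change, the running_var values should match nn.BatchNorm2d up to floating-point rounding.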

Indeed, that was it. I was playing around with it and forgot to set unbiased back to True (as the comment confirms, that was my intention). Thank you.