Recreating BatchNorm2d computation

I need a BatchNorm2d that I can freeze, so that subsequent calls behave exactly the same as the previous one, but so far this doesn't work and I am not sure why. The computation needs to be replicated exactly, otherwise my approach won't work. Here is my code attempting to replicate the BatchNorm2d forward pass; I tried to follow the official documentation:

import torch
import torch.nn as nn

_global_freeze_BN = False

def freeze_fct(freeze):
    global _global_freeze_BN
    _global_freeze_BN = freeze
    
def get_bn_freeze():
    global _global_freeze_BN
    return _global_freeze_BN

_global_log_BN = False

def log_bn_fct(log):
    global _global_log_BN
    _global_log_BN = log
    
class WrappedBatchnorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None):
        self.last_mean = None
        self.last_var = None
        super(WrappedBatchnorm2d, self).__init__(num_features, eps, momentum, affine, track_running_stats, device, dtype)
    
    def forward(self, input):
        if self.training and not _global_freeze_BN:
            if _global_log_BN:
                print("bn activated")
            # remember the batch statistics that this training-mode forward pass uses
            self.last_mean = input.mean([0, 2, 3]).detach()
            self.last_var = input.var([0, 2, 3], unbiased=False).detach()
            return super(WrappedBatchnorm2d, self).forward(input)

        if _global_freeze_BN:
            if _global_log_BN:
                print("bn deactivated")
            #assert not self.training
            # re-apply the normalization with the statistics stored above
            input = (input - self.last_mean[None, :, None, None]) / (torch.sqrt(self.last_var[None, :, None, None] + self.eps))
            if self.affine:
                input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
            return input

        return super(WrappedBatchnorm2d, self).forward(input)
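
For reference, here is a quick cross-check of the manual frozen branch against torch.nn.functional.batch_norm in inference mode (just a sketch with made-up test tensors; with training=False, F.batch_norm applies (x - mean) / sqrt(var + eps) * weight + bias using exactly the statistics passed in, so it should mirror the formula above):

import torch
import torch.nn.functional as F

bn = WrappedBatchnorm2d(3)
x = torch.rand(4, 3, 8, 8)
_ = bn(x)  # training-mode pass: stores last_mean / last_var

freeze_fct(True)
manual = bn(x)  # manual frozen normalization from the wrapper
reference = F.batch_norm(
    x, bn.last_mean, bn.last_var, bn.weight, bn.bias,
    training=False, eps=bn.eps,
)
print((manual - reference).abs().max())
print(torch.allclose(manual, reference))
freeze_fct(False)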

I test it via:

bn_test = WrappedBatchnorm2d(3, affine=False, eps=0.0)
data = torch.rand((10,3,24,24))
#data = torch.zeros((10,3,24,24))

freeze_fct(False)
assert not get_bn_freeze()
bn_outputs = bn_test(data)

freeze_fct(True)

print("BREAK")
assert get_bn_freeze()
bn_outputs_2 = bn_test(data)
assert torch.allclose(bn_outputs, bn_outputs_2)
freeze_fct(False)
assert not get_bn_freeze()

The final allclose assertion fails and I don't know why. I think I followed the documentation exactly.
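
To see how large the mismatch actually is, one can look at the maximum absolute difference instead of only the boolean result (a small diagnostic, not part of the test above):

diff = (bn_outputs - bn_outputs_2).abs().max()
print(diff)
# compare again with an explicit, looser tolerance
print(torch.allclose(bn_outputs, bn_outputs_2, atol=1e-6))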

I also have an even simpler problem:

import numpy as np

bn_test = torch.nn.BatchNorm2d(3, affine=False, eps=0.0)
bn_test.train()
data = torch.rand((10,3,24,24))
data[:,1] = 2*data[:,1]
data[:,2] = 3*data[:,2]
# rescale so that the per-channel (biased) variance is one
data = data / torch.sqrt(data.var([0,2,3], unbiased=False)[np.newaxis,:,np.newaxis,np.newaxis])
out_ref = bn_test(data)
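
As a quick sanity check of the rescaling (just a sketch; the exact values depend on the random data):

# biased per-channel variance of the rescaled data; should be close to one for every channel
print(data.var([0,2,3], unbiased=False))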

One can see via bn_test.running_var that the variance term is indeed one; combined with our epsilon of zero, we can therefore ignore the denominator:

recon_add = data - out_ref
assert torch.allclose(data.mean([0,2,3]), recon_add[0,:,0,0])
assert torch.allclose(data.mean([0,2,3])[np.newaxis,:,np.newaxis,np.newaxis], recon_add)

(No problem here!)
But now I get a small deviation:

assert torch.allclose(out_ref, (data - data.mean([0,2,3])[np.newaxis,:,np.newaxis,np.newaxis]))

Why? This should definitely work! I can't find an exact specification of the BatchNorm computation in PyTorch, and I wouldn't expect numerics to be relevant yet. And even if they were, this also doesn't work:

assert torch.allclose(out_ref, (data - data.mean([0,2,3])[np.newaxis,:,np.newaxis,np.newaxis])/torch.sqrt(data.var([0,2,3], unbiased=False)[np.newaxis,:,np.newaxis,np.newaxis]))
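
To quantify the deviation (again only a diagnostic, not part of the original code):

manual = (data - data.mean([0,2,3])[np.newaxis,:,np.newaxis,np.newaxis]) / torch.sqrt(data.var([0,2,3], unbiased=False)[np.newaxis,:,np.newaxis,np.newaxis])
# maximum absolute difference between the manual computation and the BatchNorm2d output
print((out_ref - manual).abs().max())
# compare again with an explicit, looser tolerance
print(torch.allclose(out_ref, manual, atol=1e-5))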

When comparing the buffers, this works:

assert torch.allclose(0.1*data.mean([0,2,3]), bn_test.running_mean)

but this does not:

assert torch.allclose(0.9 + 0.1*data.var([0,2,3], unbiased=False), bn_test.running_var)
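
Printing both sides next to each other shows how far apart they actually are (diagnostic only):

print(0.9 + 0.1*data.var([0,2,3], unbiased=False))
print(bn_test.running_var)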

What exactly is going on?