I need a BatchNorm2d that I can freeze, so that every call after freezing normalizes exactly the way the last unfrozen call did, but so far this doesn't work and I am not sure why. The computation has to be replicated exactly, otherwise my use case breaks. Here is my code attempting to replicate BatchNorm2d's forward pass; I tried to follow the official documentation:
import torch
import torch.nn as nn

# Module-level switches so that every WrappedBatchnorm2d in a network can be
# frozen/unfrozen (and made verbose) at once.
_global_freeze_BN = False

def freeze_fct(freeze):
    global _global_freeze_BN
    _global_freeze_BN = freeze

def get_bn_freeze():
    return _global_freeze_BN

_global_log_BN = False

def log_bn_fct(log):
    global _global_log_BN
    _global_log_BN = log

class WrappedBatchnorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True,
                 track_running_stats=True, device=None, dtype=None):
        super().__init__(num_features, eps, momentum, affine,
                         track_running_stats, device, dtype)
        # Batch statistics recorded during the most recent unfrozen
        # training-mode forward pass.
        self.last_mean = None
        self.last_var = None

    def forward(self, input):
        if self.training and not _global_freeze_BN:
            if _global_log_BN:
                print("bn activated")
            # Record the statistics that training-mode batch norm uses:
            # the per-channel mean and the *biased* variance over the batch
            # and spatial dimensions.
            self.last_mean = input.mean([0, 2, 3]).detach()
            self.last_var = input.var([0, 2, 3], unbiased=False).detach()
            return super().forward(input)
        if _global_freeze_BN:
            if _global_log_BN:
                print("bn deactivated")
            #assert not self.training
            # Replay the normalization of the last unfrozen pass (this
            # assumes at least one unfrozen training pass has happened).
            input = (input - self.last_mean[None, :, None, None]) / torch.sqrt(
                self.last_var[None, :, None, None] + self.eps)
            if self.affine:
                input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
            return input
        # Neither training nor frozen: standard eval-mode batch norm using
        # the running statistics.
        return super().forward(input)
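To sanity-check my reading of the documentation, here is a minimal standalone snippet (variable names like manual_out are my own) comparing the documented training-mode formula, with the biased batch variance, against nn.BatchNorm2d directly:

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.rand((10, 3, 24, 24))
bn = nn.BatchNorm2d(3, affine=False, eps=0.0).train()

# Per the docs, training mode normalizes with the per-channel batch mean
# and the *biased* batch variance.
mean = x.mean([0, 2, 3])
var = x.var([0, 2, 3], unbiased=False)
manual_out = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + bn.eps)
print((bn(x) - manual_out).abs().max())  # I expect only float rounding noise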
I test it via:
bn_test = WrappedBatchnorm2d(3, affine=False, eps=0.0)
data = torch.rand((10, 3, 24, 24))
#data = torch.zeros((10, 3, 24, 24))

freeze_fct(False)
assert not get_bn_freeze()
bn_outputs = bn_test(data)  # unfrozen training pass; records batch statistics

freeze_fct(True)
print("BREAK")
assert get_bn_freeze()
bn_outputs_2 = bn_test(data)  # frozen pass; should replay the recorded statistics
assert torch.allclose(bn_outputs, bn_outputs_2)
freeze_fct(False)
assert not get_bn_freeze()
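To make the comparison easier to interpret than a bare assert, I also print the actual deviation and retry with a looser tolerance (rtol=1e-05 and atol=1e-08 mentioned below are just torch.allclose's documented defaults):

diff = (bn_outputs - bn_outputs_2).abs()
print("max abs diff :", diff.max().item())
print("mean abs diff:", diff.mean().item())
# torch.allclose defaults to rtol=1e-05, atol=1e-08; also try a looser atol.
print("allclose (loose):", torch.allclose(bn_outputs, bn_outputs_2, atol=1e-6))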
It doesn't work: the allclose assertion fails, and I don't know why. I thought I had followed the documentation exactly.
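For completeness: as far as I know, the usual way to freeze batch norm is to switch the module to eval mode, but that normalizes with the running statistics rather than the statistics of the last training batch, which is what I need to replicate here:

# Standard freezing -- not what I need: eval mode normalizes with
# running_mean / running_var, not the last batch's statistics.
bn_test.eval()
out_eval = bn_test(data)
bn_test.train()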