Following this paper from Mohan et al., I would like to remove the bias terms from BatchNorm2d layers. An implementation by the authors is available here, but I'd like to stay as close as possible to the original implementation.
Re-using @ptrblck's implementation here, I wrote the following code:
import torch
import torch.nn as nn


class MyBFBatchNorm2d(nn.BatchNorm2d):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        use_bias=False,
    ):
        super(MyBFBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )
        self.use_bias = use_bias

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            if self.use_bias:
                mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                if self.use_bias:
                    self.running_mean = (
                        exponential_average_factor * mean
                        + (1 - exponential_average_factor) * self.running_mean
                    )
                # update running_var with unbiased var
                self.running_var = (
                    exponential_average_factor * var * n / (n - 1)
                    + (1 - exponential_average_factor) * self.running_var
                )
        else:
            if self.use_bias:
                mean = self.running_mean
            var = self.running_var

        # normalize (mean subtraction only when use_bias is True)
        if self.use_bias:
            input = input - mean[None, :, None, None]
        input = input / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None]
            if self.use_bias:
                input = input + self.bias[None, :, None, None]

        return input
When I instantiate a MyBFBatchNorm2d object and pass dummy data through it, it behaves as expected, i.e. self.running_var is updated while self.running_mean stays at 0.
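For completeness, this is roughly the sanity check I ran (the shapes and the offset/scale of the dummy data are arbitrary):

    import torch

    bn = MyBFBatchNorm2d(3)  # use_bias defaults to False
    x = torch.randn(8, 3, 16, 16) * 2.0 + 5.0  # non-zero mean, non-unit variance

    bn.train()
    _ = bn(x)

    print(bn.running_mean)  # stays at zeros, since the mean update is skipped
    print(bn.running_var)   # moves away from ones after one forward pass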
However, when I use this custom layer inside a neural network, training no longer converges, whereas it did with the authors' implementation.
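For context, I use the layer as a drop-in replacement for nn.BatchNorm2d. The block below is only a stand-in for my actual architecture, but it shows roughly how the layer is wired in (with bias-free convolutions, as in the paper):

    import torch.nn as nn

    class ConvBlock(nn.Module):
        # stand-in block: bias-free conv + bias-free batch norm + ReLU
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
            self.bn = MyBFBatchNorm2d(out_channels, use_bias=False)
            self.act = nn.ReLU(inplace=True)

        def forward(self, x):
            return self.act(self.bn(self.conv(x)))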
Is my custom implementation incorrect?