Remove bias from BatchNorm

Following this paper by Mohan et al., I would like to remove the bias terms from BatchNorm2d layers. An implementation by the authors is available here, but I'd like to stay as close as possible to the original implementation.
Re-using @ptrblck's implementation here, I wrote the following code:

import torch
import torch.nn as nn


class MyBFBatchNorm2d(nn.BatchNorm2d):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        use_bias=False,
    ):
        super(MyBFBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats
        )
        self.use_bias = use_bias

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            # the per-channel mean is only needed when the centering/bias step is kept
            if self.use_bias:
                mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                if self.use_bias:
                    self.running_mean = (
                        exponential_average_factor * mean
                        + (1 - exponential_average_factor) * self.running_mean
                    )
                # update running_var with unbiased var
                self.running_var = (
                    exponential_average_factor * var * n / (n - 1)
                    + (1 - exponential_average_factor) * self.running_var
                )
        else:
            if self.use_bias:
                mean = self.running_mean
            var = self.running_var

        # bias-free normalization: always divide by the standard deviation,
        # but subtract the mean only when use_bias is True
        if self.use_bias:
            input = input - mean[None, :, None, None]
        input = input / (torch.sqrt(var[None, :, None, None] + self.eps))

        if self.affine:
            input = input * self.weight[None, :, None, None]
            if self.use_bias:
                input = input + self.bias[None, :, None, None]

        return input

When I instantiate a MyBFBatchNorm2d object and pass dummy data through it, it behaves as expected: self.running_var is updated while self.running_mean stays at 0.
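
For reference, the sanity check looks roughly like this (the layer width, batch shape, and number of iterations are arbitrary):

bn = MyBFBatchNorm2d(num_features=16, use_bias=False)
bn.train()

for _ in range(5):
    dummy = torch.randn(8, 16, 32, 32)  # N, C, H, W
    _ = bn(dummy)

print(bn.running_mean)  # still all zeros, since it is never updated
print(bn.running_var)   # has moved away from all ones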

However, when I use this custom layer inside a neural network (dropped in place of nn.BatchNorm2d, roughly as in the simplified block below), training no longer converges, whereas it did with the authors' implementation.
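
For context, a minimal sketch of how the layer is plugged in (the real network is deeper, the channel counts here are arbitrary, and the conv layer is also created with bias=False so the block stays bias-free):

block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),  # no conv bias either
    MyBFBatchNorm2d(16, use_bias=False),
    nn.ReLU(inplace=True),
)
out = block(torch.randn(8, 3, 32, 32))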

Is my custom implementation incorrect?