(different issue from existing batchnorm-related posts) BatchNorm2d forward-pass transformation is different between eval() and train()

Hi, so I’ve gone through every batchnorm-related post I could find, and I think there was one about this issue, but no matter what keywords I try, I can’t find that post anymore. (I think it suggested running batchnorm without training to gather batch statistics and then doing something…)

So here’s the problem.

I have an input of [0, 1.0], whose mean is obviously 0.5 and whose variance is 0.5.

I’ve verified that during eval(), with affine=False (so no weight or bias term), the following equation is used, as per the documentation:

x_after_bn_during_eval = (x - running_mean) / sqrt(running_var + eps)
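Just to sanity-check that formula numerically, here is a minimal sketch (assuming eps = 1e-5 and running stats that have converged to 0.5 / 0.5):

import torch

x = torch.tensor([0.0, 1.0]).view(1, 1, 1, 2)
running_mean = torch.tensor([0.5])   # assumed converged running stats
running_var = torch.tensor([0.5])
eps = 1e-5

# eval-mode BN with affine=False: normalize with the running statistics only
x_after_bn_during_eval = (x - running_mean[None, :, None, None]) / torch.sqrt(running_var[None, :, None, None] + eps)
print(x_after_bn_during_eval)   # tensor([[[[-0.7071, 0.7071]]]])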

So in the following code, after running for 50 iterations or so, enough batch statistics have been gathered and the BN layer correctly estimates a running mean of 0.5 and a running variance of 0.5.

And x_after_bn_during_eval gets computed as ([0, 1] - 0.5) / sqrt(0.5 + 1e-5) ≈ [-0.707, 0.707].

import torch
from torch import nn
import numpy as np

class mynet(torch.nn.Module):
    def __init__(self):
        super(mynet, self).__init__()
        self.fc = nn.Linear(2, 1, bias=False)
        self.bn = nn.BatchNorm2d(1, affine=False)  # no learnable weight or bias

    def forward(self, x, printbn=False):
        net = self.bn(x)

        if printbn:
            print("batchnorm params; running mean: {}, running_var: {}".format(self.bn.running_mean, self.bn.running_var))
            print("before running bn: {}".format(x))
            print("after running bn: {}".format(net))

        net = self.fc(net)
        return net


model = mynet()

data = np.array([0.0, 1.0])
x = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0).unsqueeze(0)
gt = torch.from_numpy(np.array([0.0])).float().unsqueeze(0).unsqueeze(0).unsqueeze(0)

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.01)

for i in range(100):
    # print("-----------------------------------------iteration: {}----------------------------".format(i))
    model.train()
    optimizer.zero_grad()
    if i == 50:
        print("------train()------")
        result = model(x=x, printbn=True)
        print("------eval()------")
        model.eval()
        eval_result = model(x=x, printbn=True)
    else:
        result = model(x=x)

    loss = criterion(result, gt)
    # print("loss: {}".format(loss))

    loss.backward()
    optimizer.step()


However, during train() the transformation is different, and I get
x_after_bn_during_train = [-1.0, 1.0]

If you run the code, you get:
------train()------
batchnorm params; running mean: tensor([0.4977]), running_var: tensor([0.5023])
before running bn: tensor([[[[0., 1.]]]])
after running bn: tensor([[[[-1.0000, 1.0000]]]]) --> I expected this to be ~ -0.7, 0.7
------eval()------
batchnorm params; running mean: tensor([0.4977]), running_var: tensor([0.5023])
before running bn: tensor([[[[0., 1.]]]])
after running bn: tensor([[[[-0.7022, 0.7087]]]])

This means that during training the weights are learned using [-1, 1], while during evaluation [-0.7, 0.7] is used…
So if a network has a lot of non-linear activation functions such as ReLUs, and the BN layers actually had weights and biases, wouldn’t evaluation (whether cross-validation or testing in eval() mode) fail miserably while training (in train() mode) reports good accuracy?

What I’m trying to get at is: what is the exact transformation applied to the inputs during training mode, and why would it be different from the transformation that uses the correct mean and variance of the data during eval() mode?

Is the solution to first run the forward pass multiple times to capture the statistics of the data without updating the weights (under torch.no_grad()), and then put all batchnorm layers in eval() mode both during training (with the rest of the network in train()) AND during validation/testing? It kind of seems backwards…
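(What I mean by putting only the batchnorm layers into eval() is something like this sketch, with a hypothetical set_bn_eval helper, while the rest of the model stays in train():)

def set_bn_eval(model):
    # switch only the BatchNorm layers to eval mode so they normalize with
    # their running statistics and stop updating them
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.eval()

model.train()        # rest of the network keeps training normally
set_bn_eval(model)   # BN layers now use running_mean / running_var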

During training, the biased variance is calculated from the batch and used to normalize the activations, while the unbiased variance is used to update the running stats.
Code to reproduce:

x = torch.tensor([0.0, 1.0]).view(1, 1, 2, 1)
var = x.var([0, 2, 3], unbiased=False)   # biased (divide-by-n) batch variance
mean = x.mean([0, 2, 3])

# normalize with the batch statistics, as BN does in train mode (eps omitted for brevity)
out = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None])

Also, have a look at this manual batchnorm approach I’ve written some time ago.
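To see both statistics at once, here is a minimal comparison sketch (assuming momentum=1.0 so that a single forward pass completely overwrites the running stats):

import torch
from torch import nn

x = torch.tensor([0.0, 1.0]).view(1, 1, 2, 1)

bn = nn.BatchNorm2d(1, affine=False, momentum=1.0)
bn.train()
out_train = bn(x)

print(out_train.view(-1))   # tensor([-1.0000, 1.0000]) -> normalized with the biased variance (0.25)
print(bn.running_var)       # tensor([0.5000])          -> running stat stores the unbiased variance (0.5)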

I see, so the main difference is the (n-1) vs. (n) in the denominator, so with enough data in each batch the difference will lessen, right?

Thanks,

Yes, the difference should get smaller with an increasing number of samples.
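For example, the ratio between the unbiased and the biased estimate is exactly n / (n - 1), which you can check with a quick sketch:

import torch

for n in [2, 10, 100, 10000]:
    x = torch.randn(n)
    biased = x.var(unbiased=False)        # divides by n   (used to normalize in train mode)
    unbiased = x.var(unbiased=True)       # divides by n-1 (used for the running stats)
    print(n, (unbiased / biased).item())  # ratio is n / (n - 1), which approaches 1 as n grows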