Changing batch norm statistics doesn't change the result

Hi,

During inference, I switch the batch norm modules to training mode (batchnorm.train()) and manually change the running variance at every inference step with this code:

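```python
import torch

# Put every batch norm layer of the encoder in training mode and scale its running variance.
for m in encoder.modules():
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        print("Running Var", m.running_var)
        m.train()
        m.running_var = 100 * m.running_var
```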

I print the batch norm running variance at every inference step, and it keeps increasing (eventually becoming inf in the later steps).

However, the output of the network doesn't change. My model is an autoencoder that takes images as input and reconstructs them. The reconstructed images don't change when the batch norm statistics change. I'm wondering why this happens.
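A minimal reproduction of this observation, assuming a single `nn.BatchNorm2d` layer on random input stands in for the encoder's batch norm modules (the shapes and seed are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# One batch norm layer stands in for the encoder's batch norm modules (assumption).
bn = nn.BatchNorm2d(3)
x = torch.randn(4, 3, 8, 8)

bn.train()
out_before = bn(x)

# Scale the running variance, as the loop above does.
with torch.no_grad():
    bn.running_var.mul_(100)
out_after = bn(x)

# In train mode the layer normalizes with the current batch statistics,
# so the output is unchanged even though running_var was scaled.
print(torch.allclose(out_before, out_after))  # True
```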

I use PyTorch Lightning for the model. Could the library be the cause?

Update:

I tried m.eval() instead, and the results make sense now.

I wonder if this is because m.running_mean, m.running_var, and m.momentum can't be changed in m.train() mode.

But if so, why does m.running_var still change? As mentioned above, it keeps increasing over the inference steps until it reaches inf.

```python
import torch

# Same loop as before, but with the batch norm layers in evaluation mode.
for m in encoder.modules():
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        print("Running Var", m.running_var)
        m.eval()
        m.running_var = 100 * m.running_var
```

In train() mode, the running stats are updated using the batch stats, and the input is normalized with the batch stats. In eval() mode, the running stats are used to normalize the input and are no longer updated.
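The update rule can be checked directly. This is a sketch assuming a small `nn.BatchNorm1d` with PyTorch's default `momentum=0.1`: in train mode the running stats move toward the batch stats as `(1 - momentum) * running + momentum * batch_stat` (the variance term uses the unbiased batch variance), while in eval mode they stay fixed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
momentum = 0.1  # PyTorch's default BatchNorm momentum
bn = nn.BatchNorm1d(2, momentum=momentum)
x = torch.randn(16, 2)

old_mean = bn.running_mean.clone()
old_var = bn.running_var.clone()

# train(): the forward pass updates the running stats from the batch stats.
bn.train()
bn(x)
expected_mean = (1 - momentum) * old_mean + momentum * x.mean(dim=0)
# The running variance uses the unbiased batch variance (Tensor.var's default).
expected_var = (1 - momentum) * old_var + momentum * x.var(dim=0)
mean_ok = torch.allclose(bn.running_mean, expected_mean)
var_ok = torch.allclose(bn.running_var, expected_var)

# eval(): the running stats normalize the input and stay unchanged.
bn.eval()
frozen_mean = bn.running_mean.clone()
bn(x)
frozen_ok = torch.equal(bn.running_mean, frozen_mean)

print(mean_ok, var_ok, frozen_ok)  # True True True
```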

You are manually increasing it, so it's expected that the values overflow at some point.
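A rough sketch of the overflow arithmetic, assuming the running variance starts at BatchNorm's initial value of 1.0 in float32 and is multiplied by 100 per step. It deliberately ignores the momentum update that train mode also applies, so it only approximates the real loop.

```python
import torch

# float32 like default BatchNorm buffers; starts at BatchNorm's initial running_var.
var = torch.ones(1)
steps = 0
while torch.isfinite(var).all():
    var = var * 100  # the manual scaling from the loop (momentum update ignored)
    steps += 1

# 100**19 = 1e38 still fits below float32's max (about 3.4e38); 100**20 overflows.
print(steps)  # 20
print(torch.finfo(torch.float32).max)
```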


Hi,
Thank you for your reply.
What I mean is that when I manually set the batch stats (running_mean, running_var, momentum), the output doesn't change in .train() mode.

For example, if I set running_var = inf in .eval() mode, the output vector is [beta, beta, beta, …, beta], where beta is the affine bias parameter of the batch norm transformation.
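This eval-mode result follows from the formula: with an infinite variance, `(x - running_mean) / sqrt(inf + eps)` is 0, so `gamma * 0 + beta` is just `beta`. A small sketch, assuming an `nn.BatchNorm1d` with illustrative bias values and random input:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
with torch.no_grad():
    bn.bias.copy_(torch.tensor([0.5, -1.0, 2.0, 3.0]))  # illustrative beta values
    bn.running_var.fill_(float("inf"))

bn.eval()
x = torch.randn(8, 4)
out = bn(x)

# Every row of the output equals beta because the normalized input is 0.
print(torch.allclose(out, bn.bias.expand_as(out)))  # True
```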

But if I switch to .train() mode, even if I manually set running_var = inf, the results look like [x1, x2, x3, …, xM]. They should be the beta vector if the stats I set were really being used.

Does this mean we cannot change batch stats manually in .train() mode?
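A sketch of the train-mode counterpart, under the same assumptions (a standalone `nn.BatchNorm1d`, random input, default `weight=1` and `bias=0`): the output matches plain normalization with the batch's own mean and biased variance, and the running variance stays at inf because `(1 - momentum) * inf + momentum * batch_var` is still inf.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
with torch.no_grad():
    bn.running_var.fill_(float("inf"))

bn.train()
x = torch.randn(8, 4)
out = bn(x)

# The output is not beta (all zeros here): the inf running_var is ignored...
print(torch.allclose(out, bn.bias.expand_as(out)))  # False
# ...because train mode normalizes with the batch mean and biased batch variance.
manual = (x - x.mean(dim=0)) / torch.sqrt(x.var(dim=0, unbiased=False) + bn.eps)
print(torch.allclose(out, manual, atol=1e-5))  # True
# The momentum update keeps the running variance at inf.
print(torch.isinf(bn.running_var).all().item())  # True
```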

This is expected, since the running stats are only used to normalize the input during eval(). During train(), as already explained, the running stats are updated but not used to normalize the input; the batch stats are used instead. Maybe looking at this manual implementation clarifies the behavior.
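A minimal sketch in the same spirit, not necessarily the implementation referred to above. `manual_batch_norm` is a hypothetical helper for `(N, C, H, W)` input that mirrors `nn.BatchNorm2d` with default settings: in training mode it normalizes with batch stats and only updates the running stats; in eval mode it normalizes with the running stats and updates nothing.

```python
import torch
import torch.nn as nn


def manual_batch_norm(x, running_mean, running_var, weight, bias,
                      training, momentum=0.1, eps=1e-5):
    """Hypothetical re-implementation of BatchNorm2d's forward pass."""
    if training:
        # Normalize with the statistics of the current batch (biased variance).
        mean = x.mean(dim=(0, 2, 3))
        var = x.var(dim=(0, 2, 3), unbiased=False)
        n = x.numel() / x.size(1)
        with torch.no_grad():
            # Running stats are updated here (unbiased variance) but not used above.
            running_mean.mul_(1 - momentum).add_(momentum * mean)
            running_var.mul_(1 - momentum).add_(momentum * var * n / (n - 1))
    else:
        # Normalize with the stored running statistics; nothing is updated.
        mean, var = running_mean, running_var
    shape = (1, -1, 1, 1)
    x_hat = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + eps)
    return x_hat * weight.view(shape) + bias.view(shape)


torch.manual_seed(0)
x = torch.randn(4, 3, 5, 5)
bn = nn.BatchNorm2d(3)
rm, rv = bn.running_mean.clone(), bn.running_var.clone()

results = {}
for training in (True, False):
    bn.train(training)
    ref = bn(x)
    out = manual_batch_norm(x, rm, rv, bn.weight, bn.bias, training)
    results[training] = (torch.allclose(out, ref, atol=1e-5),
                         torch.allclose(rm, bn.running_mean),
                         torch.allclose(rv, bn.running_var))
    print(training, results[training])
```

Comparing against `nn.BatchNorm2d` in both modes makes the asymmetry explicit: overwriting `running_var` can only affect the output through the `else` branch, that is, in eval mode.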

Thank you so much.
The implementation you give me is very clear and helpful.