Model.eval() gives incorrect loss for model with batchnorm layers

Lowering the momentum might help in situations where the data is quite noisy.
So it’s good to hear you got reasonable results for your validation set by changing it!

In the example code, you were feeding samples from two domains, so the running estimates settled on the average of the two (which now lies outside either domain).

Yup, I understood what you wrote earlier. I didn’t succeed in replicating my actual issue with my example code above.

How can I do this? Can we do it with “with torch.no_grad()”?

Yes, you can call model.train() and perform some forward passes in a torch.no_grad() block.
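
A minimal sketch of that suggestion, assuming a toy model and random data (both are hypothetical stand-ins for the real model and loader). It only shows that forward passes in train mode under torch.no_grad() refresh the BatchNorm running stats without touching the weights:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and data, standing in for the real ones.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
batches = [torch.randn(16, 4) * 3.0 + 2.0 for _ in range(50)]

bn = model[1]
mean_before = bn.running_mean.clone()
weight_before = model[0].weight.clone()

model.train()              # BatchNorm updates its running stats only in train mode
with torch.no_grad():      # no autograd graph is built, no optimizer step is taken
    for x in batches:
        model(x)
model.eval()               # switch back before validating

stats_changed = not torch.equal(mean_before, bn.running_mean)
weights_unchanged = torch.equal(weight_before, model[0].weight)
print(stats_changed, weights_unchanged)
```

Note that model.train() also re-enables dropout if the model has any, which does not matter for the running stats of layers before the dropout but may for layers after it.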


Thanks for your solution. I recently tried the pretrained MnasNet model, which has this problem. The momentum in that model is small (~3e-4), so beware of that if you are using it.
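
One way to check for this, sketched on a hypothetical toy model rather than the actual MnasNet: list the momentum of every BatchNorm layer and, if it is very small, raise it before re-estimating the running stats. Whether 0.1 is a sensible replacement value is an assumption here, not something the thread confirms.

```python
import torch.nn as nn

# Toy stand-in for a pretrained network with a very small BatchNorm momentum.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8, momentum=3e-4),
    nn.ReLU(),
)

bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

# Collect the momentum of every BatchNorm layer by module name.
momenta = {
    name: module.momentum
    for name, module in model.named_modules()
    if isinstance(module, bn_types)
}
print(momenta)

# Optionally override it so the running stats adapt in fewer forward passes.
for module in model.modules():
    if isinstance(module, bn_types):
        module.momentum = 0.1
```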

Hi, I had a similar issue, and one thing I realized was that I had defined a single batch-norm layer and reused it after every layer. This might be the mistake you are making.

For example:
self.batch_norm_hidden = nn.BatchNorm1d(num_hidden_nodes)

and then later:

for layer in layers:
    x = layer(x)
    x = self.activation_fn(x)
    x = self.batch_norm_hidden(x)

This is wrong, because the same batch_norm_hidden is used everywhere. You need to define a new batch-norm layer for every layer; otherwise the running stats are shared across the layers.
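
A minimal sketch of the fix, assuming a plain MLP (the class name MLP and its constructor arguments are made up for illustration). Each hidden layer gets its own BatchNorm1d, kept in an nn.ModuleList so the layers are registered with the model:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, num_inputs, num_hidden_nodes, num_layers):
        super().__init__()
        sizes = [num_inputs] + [num_hidden_nodes] * num_layers
        self.layers = nn.ModuleList(
            [nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
        )
        # One independent BatchNorm1d per hidden layer, each with its own running stats.
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(num_hidden_nodes) for _ in range(num_layers)]
        )
        self.activation_fn = nn.ReLU()

    def forward(self, x):
        for layer, batch_norm in zip(self.layers, self.batch_norms):
            x = layer(x)
            x = self.activation_fn(x)
            x = batch_norm(x)
        return x

model = MLP(num_inputs=4, num_hidden_nodes=8, num_layers=3)
out = model(torch.randn(16, 4))
print(out.shape)
```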

Also, I think the momentum for the batchnorm layer is defined differently than for the optimizers. It might be helpful to check the docs to see how new batch statistics are weighted.
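
For reference, the BatchNorm docs give the update as running = (1 - momentum) * running + momentum * batch_stat, so momentum is the weight of the new batch, not of the history. A small sketch that checks this against a single forward pass (the input shape and values are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

momentum = 0.1
bn = nn.BatchNorm1d(3, momentum=momentum)  # modules start in train mode
x = torch.randn(32, 3) * 2.0 + 5.0

old_mean = bn.running_mean.clone()  # initialised to zeros
old_var = bn.running_var.clone()    # initialised to ones

bn(x)  # one forward pass in train mode updates the running stats

# The running variance is updated with the unbiased batch variance (the default of Tensor.var).
expected_mean = (1 - momentum) * old_mean + momentum * x.mean(0)
expected_var = (1 - momentum) * old_var + momentum * x.var(0)

mean_matches = torch.allclose(bn.running_mean, expected_mean, atol=1e-5)
var_matches = torch.allclose(bn.running_var, expected_var, atol=1e-5)
print(mean_matches, var_matches)
```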

I had a similar problem, and setting track_running_stats=False fixed it for me (UNet, no dropout, just batchnorm). But I still don’t understand why. Firstly, does setting that flag affect training in any way, or only evaluation?
Secondly, for debugging purposes I tried training, validating, and evaluating after training on the same single image, and also on the same single batch of 8 images. In both cases the problem (discrepancies between validation and training loss, or between network outputs under net.eval() vs. net.train()) persisted.
But if I understand your explanation of track_running_stats correctly, it should not make a difference in my case, since the data is literally the same in all stages (train, val, and test). So why does the flag still seem to help?


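This argument affects validation: with track_running_stats=False no running stats are calculated, so all validation inputs are normalized using their own batch statistics. The training forward pass is unchanged, since it uses the batch statistics either way.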
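
A small sketch of that difference on a single layer, with arbitrary random input. Without tracked stats, train and eval mode give the same output; with tracked stats that have seen only one batch, they do not:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 3, 16, 16) * 5.0 + 7.0

# No running stats: eval mode falls back to the statistics of the current batch.
bn_untracked = nn.BatchNorm2d(3, track_running_stats=False)
out_train = bn_untracked(x)
bn_untracked.eval()
out_eval = bn_untracked(x)
same_untracked = torch.allclose(out_train, out_eval)

# Default layer: eval mode uses the running stats, which are far from
# the batch stats after a single update with momentum 0.1.
bn_tracked = nn.BatchNorm2d(3)
out_train_tracked = bn_tracked(x)
bn_tracked.eval()
out_eval_tracked = bn_tracked(x)
same_tracked = torch.allclose(out_train_tracked, out_eval_tracked)

print(same_untracked, same_tracked)
```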

The running stats will be updated using the momentum as described in the docs, so you would probably need more forward passes to let the running stats converge towards the batch stats.

Thanks for the answer.
Just to be sure: during training, when the running stats are computed, it feels like with a single batch and momentum 0.1 they should definitely converge to the batch statistics after 100 epochs, right?
And they still haven’t in my case (the discrepancies do generally become smaller, but we’re still talking about, for example, a Dice score of 0.6 vs. 0.4, or sometimes more).
So I wonder whether I have another issue in my code. Would you generally expect convergence to the batch statistics to take this long?

Yes, I would assume that the running stats converge after 100 iterations as is also shown here:

import torch
import torch.nn as nn

x = torch.randn(1, 3, 100, 100) * 5. + 7.
bn = nn.BatchNorm2d(3)  # default momentum=0.1, starts in train mode

# Per-channel statistics of the single input batch
mean, var = x.mean([0, 2, 3]), x.var([0, 2, 3])

# Every forward pass in train mode updates the running stats
for _ in range(100):
    out = bn(x)

bn_mean, bn_var = bn.running_mean, bn.running_var

print('sample mean {}\nsample var {}'.format(mean, var))
# sample mean tensor([7.0567, 7.0104, 6.9218])
# sample var tensor([24.7264, 24.9284, 24.9140])

print('bn mean {}\nbn var {}'.format(bn_mean, bn_var))
# bn mean tensor([7.0565, 7.0102, 6.9216])
# bn var tensor([24.7258, 24.9278, 24.9134])
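
The small remaining gap in the numbers above is what the update rule predicts for a fixed input: after n updates, a fraction (1 - momentum)^n of the initial gap is left, which for momentum 0.1 and n = 100 is about 2.7e-5. A plain-Python sketch of that arithmetic, using the first channel's sample mean from the output above and assuming the running mean starts at zero:

```python
momentum = 0.1
num_updates = 100
batch_mean = 7.0567  # first channel of the sample mean printed above

# Apply the running-stat update repeatedly to the same batch statistic.
running = 0.0  # BatchNorm initialises running_mean to zero
for _ in range(num_updates):
    running = (1 - momentum) * running + momentum * batch_mean

# Closed form of the same recursion.
closed_form = batch_mean * (1 - (1 - momentum) ** num_updates)
remaining_fraction = (1 - momentum) ** num_updates
gap = batch_mean - running

print(running, closed_form, remaining_fraction, gap)
```

With a much smaller momentum the remaining fraction shrinks far more slowly, so many more forward passes are needed before the running stats match the batch stats.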


It depends a bit on your use case. Note that the intermediate activations are not static, since the parameters are updated in each iteration. This could also mean that the stats are changing and that the bn layers are tracking these changes, so you cannot directly assume that x iterations will make the stats converge perfectly.
