Source of error in Batch Normalization


I am testing my understanding of torch.nn.BatchNorm2d, but there is a slight numerical discrepancy that I have failed to resolve.

I thought that Batch Normalization in training mode (without any parameter updates) should compute

(x - E[x]) / sqrt(var(X) + 1e-5)

which in Python code is

(x - x.mean([0,2,3], keepdim=True)) / (x.var([0,2,3], keepdim=True) + 1e-5).sqrt()

But when I actually compare it to torch.nn.BatchNorm2d with the following code:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(5)
x = torch.randn(500,5,2,2)

bn_x = (x - x.mean([0,2,3], keepdim=True)) / (x.var([0,2,3], keepdim=True) + 1e-5).sqrt()
print((bn(x) - bn_x).abs().max().item())

I consistently get a nonzero difference on the order of ~1e-3.
The error decreases as B, H, and W increase.

I tried Bessel’s correction with different numbers, but it did not work.
I tried removing the 1e-5 (the eps parameter of nn.BatchNorm2d), but that did not work either.
It might be due to floating-point precision, but the error seems too large for that.

Could you point me to the source of this error?

I would like an answer to the following questions:

  • Is my understanding of nn.BatchNorm2d correct?
  • What should I change in my formula / code to remove the difference?

I guess the difference comes from the var calculation: BatchNorm2d normalizes with the biased variance, while x.var defaults to the unbiased (Bessel-corrected) estimate.
I’ve written a manual implementation some time ago here so feel free to compare your code to this one.
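To illustrate, here is a minimal sketch of the comparison from the question with the variance switched to the biased estimate via unbiased=False (and bn.eps used instead of a hard-coded 1e-5); with this change the two outputs should agree up to float32 rounding error:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(5)
x = torch.randn(500, 5, 2, 2)

# In training mode, BatchNorm2d normalizes the batch with the *biased*
# variance, whereas torch.var defaults to the unbiased estimate.
bn_x = (x - x.mean([0, 2, 3], keepdim=True)) / (
    x.var([0, 2, 3], unbiased=False, keepdim=True) + bn.eps
).sqrt()

# The difference now drops to float32 rounding error.
print((bn(x) - bn_x).abs().max().item())
```

Note that the running statistics (running_var) are still updated with the unbiased variance; only the normalization of the current batch uses the biased one.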

Thank you so much for your answer. It was more than perfect.
