How does PyTorch implement batch normalization?

Sample code to explain what I mean:

from torch.nn.functional import batch_norm
from torch import from_numpy
import numpy as np

# Random input plus per-channel running stats and affine parameters
M = np.random.randn(1, 3, 16, 16).astype(np.float32)
m, v, w, b = [np.random.rand(3).astype(np.float32) for _ in range(4)]

# PyTorch batch norm in eval mode (training=False is the default)
M_torch = batch_norm(*[from_numpy(x) for x in [M, m, v, w, b]]).numpy()

# Manual NumPy version, broadcasting the parameters over the channel dim
m, v, w, b = [p.reshape(1, 3, 1, 1) for p in [m, v, w, b]]
M_np = w * (M - m) / np.sqrt(v + 1e-05) + b
print(np.max(np.abs(M_np - M_torch)))

The error is quite large. So why isn’t my NumPy implementation of batch normalization exactly identical to PyTorch’s? Note that the epsilon I used, 1e-05, is large (it matches PyTorch’s default). But if I use a smaller epsilon like np.finfo(np.float32).eps, the results differ even more. So clearly PyTorch is doing something slightly different, and I don’t think it can be chalked up to floating-point associativity.
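For concreteness, here is the smaller-epsilon variant I tried, as an illustrative sketch (the same eps is passed to both computations):

import numpy as np
from torch import from_numpy
from torch.nn.functional import batch_norm

M = np.random.randn(1, 3, 16, 16).astype(np.float32)
m, v, w, b = [np.random.rand(3).astype(np.float32) for _ in range(4)]

eps = np.finfo(np.float32).eps  # ~1.19e-07 instead of the 1e-05 default
M_torch = batch_norm(*[from_numpy(x) for x in [M, m, v, w, b]], eps=eps).numpy()

mr, vr, wr, br = [p.reshape(1, 3, 1, 1) for p in [m, v, w, b]]
M_np = wr * (M - mr) / np.sqrt(vr + eps) + br
print(np.max(np.abs(M_np - M_torch)))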

This script shows the internal batchnorm approach used in PyTorch.
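The script itself isn’t quoted here, but a minimal sketch of that kind of comparison, assuming eval-mode batch norm with the running stats supplied directly, could look like this (out1 and out2 are illustrative names):

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 16, 16)
mean, var = torch.rand(3), torch.rand(3)
weight, bias = torch.rand(3), torch.rand(3)

# F.batch_norm in eval mode uses the provided running stats directly
out1 = F.batch_norm(x, mean, var, weight, bias, training=False, eps=1e-5)

# Manual computation, broadcasting the per-channel parameters
m, v, w, b = [t.reshape(1, 3, 1, 1) for t in (mean, var, weight, bias)]
out2 = w * (x - m) / torch.sqrt(v + 1e-5) + b

print('%12.10f' % (out1 - out2).abs().max())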

Did you compare it against the implementation shared in my post, or is this an unhelpful chatbot post?


Your example does exactly the same calculation as my code. Here is what I get when I print the differences with 10 decimal places of precision:

print('%12.10f' % (out1 - out2).abs().max())

0.0000002384
0.0000019073
0.0000009537
0.0000019073
0.0000004768
0.0000019073

So your code shows the same deviations. Why is that? And why is epsilon 1e-5 rather than a smaller value?

That’s expected: performing mathematically equivalent operations in a different order produces small differences due to the limited floating-point precision, as seen e.g. here:

import torch

x = torch.randn(100, 100)
s1 = x.sum()          # reduce all elements in a single call
s2 = x.sum(0).sum(0)  # reduce one dimension at a time

print(s1 - s2)
# tensor(1.5259e-05)
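The same effect shows up inside the batch norm formula itself. Here is a sketch assuming an implementation folds the weight and the denominator into a single per-channel scale before applying it; that ordering is an illustrative assumption, not a claim about PyTorch's exact kernel:

import torch

torch.manual_seed(0)
x = torch.randn(1, 3, 16, 16)
m, v, w, b = [torch.rand(1, 3, 1, 1) for _ in range(4)]
eps = 1e-5

# Textbook ordering: normalize first, then scale
out_a = w * (x - m) / torch.sqrt(v + eps) + b

# Folded ordering: precompute one scale factor per channel first
scale = w / torch.sqrt(v + eps)
out_b = (x - m) * scale + b

# Typically a small nonzero value on the order of float32 rounding error
print((out_a - out_b).abs().max())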