Sample code to explain what I mean:
from torch.nn.functional import batch_norm
from torch import from_numpy
import numpy as np

# Random input plus per-channel running mean, running variance, weight and bias
M = np.random.randn(1, 3, 16, 16).astype(np.float32)
m, v, w, b = [np.random.rand(3).astype(np.float32) for _ in range(4)]

# PyTorch: batch_norm(input, running_mean, running_var, weight, bias), training=False by default
M_torch = batch_norm(*[from_numpy(x) for x in [M, m, v, w, b]]).numpy()

# NumPy: same formula, with the per-channel parameters broadcast over the channel axis
m, v, w, b = [p.reshape(1, 3, 1, 1) for p in [m, v, w, b]]
M_np = w * (M - m) / np.sqrt(v + 1e-05) + b

print(np.max(np.abs(M_np - M_torch)))
The error is quite large. So why isn't my NumPy implementation of batch normalization exactly identical to PyTorch's? Note that the epsilon I used, 1e-05, is already fairly large. If I use a smaller epsilon such as np.finfo(np.float32).eps, the results differ even more. So PyTorch is clearly doing something slightly different, and I don't think it can be chalked up to floating-point associativity.
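For concreteness, here is the smaller-epsilon comparison I mention, written as a continuation of the script above (the variable names are just for illustration; batch_norm still expects 1-D per-channel parameters, so I flatten them back before the call):

# Repeat the comparison with a much smaller epsilon, passed to both implementations
eps_small = float(np.finfo(np.float32).eps)
m1, v1, w1, b1 = [p.reshape(3) for p in [m, v, w, b]]  # back to 1-D for batch_norm
M_torch_small = batch_norm(*[from_numpy(x) for x in [M, m1, v1, w1, b1]], eps=eps_small).numpy()
M_np_small = w * (M - m) / np.sqrt(v + eps_small) + b
print(np.max(np.abs(M_np_small - M_torch_small)))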