According to the documentation for batch_norm, running_mean and running_var are updated according to the following equation:

```
running_mean = momentum * batch_mean + (1 - momentum) * running_mean
running_var = momentum * batch_var + (1 - momentum) * running_var
```

Consider the following input:

```
a = torch.tensor([[-1., -1.],
                  [-1., -1.],
                  [-1.,  1.]])
mean = torch.ones(2)
var = torch.zeros(2)
```

Running `torch.nn.functional.batch_norm(a, mean, var, training=True)`

updates `var` in place to `tensor([0.0000, 0.1333])`,

which does not match the formula above.

According to the above formula, the result should be equivalent to:

```
batch_var = torch.var(a, unbiased=False, dim=0)
# use the original running_var (zeros), since batch_norm updated `var` in place
print(0.1 * batch_var + (1 - 0.1) * torch.zeros(2))
```

which prints `tensor([0.0000, 0.0889])`.


EDIT: If we use `torch.var(a, unbiased=True, dim=0)` instead, we get the same result as `batch_norm`. However, according to the documentation, "The standard-deviation is calculated via the biased estimator, equivalent to torch.var(input, unbiased=False)."
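For completeness, here is the full comparison in one self-contained script (with `momentum` written out explicitly at its default of 0.1), showing that the in-place update matches the unbiased estimator rather than the documented biased one:

```python
import torch
import torch.nn.functional as F

a = torch.tensor([[-1., -1.],
                  [-1., -1.],
                  [-1.,  1.]])
momentum = 0.1  # batch_norm's default

running_mean = torch.ones(2)
running_var = torch.zeros(2)
# batch_norm updates running_mean/running_var in place when training=True
F.batch_norm(a, running_mean, running_var, training=True, momentum=momentum)
print(running_var)  # tensor([0.0000, 0.1333])

# The documented (biased) formula does NOT reproduce this
# (the (1 - momentum) * running_var term is zero here):
biased = momentum * torch.var(a, unbiased=False, dim=0)
print(biased)  # tensor([0.0000, 0.0889])

# The unbiased estimator does:
unbiased = momentum * torch.var(a, unbiased=True, dim=0)
print(unbiased)  # tensor([0.0000, 0.1333])
```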