`track_running_stats` is used to initialize the running estimates (`running_mean`, `running_var`) and to check whether they should be updated during training (see the corresponding line of code in the BatchNorm source).
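To show the initialization side, here is a minimal sketch assuming a recent PyTorch release. With `track_running_stats=False` no running buffers are created, so the layer normalizes with the current batch statistics even in eval mode. The variable names `bn_no_stats` and `bn_default` are illustrative only.

```python
import torch
import torch.nn as nn

# Without running stats, the buffers are not created at all
bn_no_stats = nn.BatchNorm2d(3, track_running_stats=False)
print(bn_no_stats.running_mean)  # None
print(bn_no_stats.running_var)   # None

# Default layer: running_mean starts at zeros, running_var at ones
bn_default = nn.BatchNorm2d(3)
print(bn_default.running_mean)  # tensor([0., 0., 0.])
print(bn_default.running_var)   # tensor([1., 1., 1.])

# In eval mode, the layer without running stats still uses batch statistics,
# so each output channel is normalized to (approximately) zero mean
bn_no_stats.eval()
x = torch.randn(10, 3, 24, 24) * 5 + 2
out = bn_no_stats(x)
print(out.mean(dim=(0, 2, 3)))  # values close to 0 for every channel
```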
The running estimates won't be updated in eval mode:
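```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
for _ in range(10):
    x = torch.randn(10, 3, 24, 24)
    out = bn(x)  # training mode (default): running stats are updated

print(bn.running_mean)
print(bn.running_var)
```
```
> tensor(1.00000e-03 *
        [-0.7753,  0.7027, -1.4181])
  tensor([ 1.0015,  1.0021,  0.9947])
```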
```python
bn.eval()
for _ in range(10):
    x = torch.randn(10, 3, 24, 24)
    out = bn(x)  # eval mode: running stats are used, not updated

print(bn.running_mean)
print(bn.running_var)
```
```
> tensor(1.00000e-03 *
        [-0.7753,  0.7027, -1.4181])
  tensor([ 1.0015,  1.0021,  0.9947])
```
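The same check can be done without comparing printed values by hand. This is a minimal sketch assuming a PyTorch version that has the `num_batches_tracked` buffer (added around 0.4.1); the names `mean_after_train` and `var_after_train` are illustrative. It snapshots the running estimates after training-mode passes and verifies that eval-mode passes leave them, and the batch counter, untouched.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)

# Training mode: every forward pass updates the running estimates
bn.train()
for _ in range(10):
    bn(torch.randn(10, 3, 24, 24))
mean_after_train = bn.running_mean.clone()
var_after_train = bn.running_var.clone()
print(bn.num_batches_tracked)  # tensor(10)

# Eval mode: forward passes read the running estimates but never write them
bn.eval()
for _ in range(10):
    bn(torch.randn(10, 3, 24, 24))
print(torch.equal(bn.running_mean, mean_after_train))  # True
print(torch.equal(bn.running_var, var_after_train))    # True
print(bn.num_batches_tracked)  # still tensor(10)
```

Cloning matters here: the buffers are updated in place, so keeping a plain reference instead of a copy would make the comparison trivially true.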