Batchnorm forward and backward - updates of statistics

Hi.
I tried looking at the C code and at the code in the docs, but still couldn't figure this out.
Generally, BatchNorm aggregates the mean and variance of the samples it sees in train mode and then simply uses them in test mode. My question is: is this aggregation, i.e. updating the running mean and variance with the statistics of the current iteration, done in the forward pass, or, as is more conventional, in the backward pass? In other words, do the internal statistics include the current iteration's statistics only after backward() and step() have been called?

Thank you in advance for clarification.

The running stats are updated during the forward pass (in training mode), as seen here:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
x = torch.randn(16, 3, 24, 24)
print(bn.running_mean)
# tensor([0., 0., 0.])

print(bn.running_var)
# tensor([1., 1., 1.])

out = bn(x)  # forward pass in train mode (the default for a new module)
print(bn.running_mean)
# tensor([ 0.0002, -0.0007,  0.0006])

print(bn.running_var)
# tensor([0.9991, 1.0022, 0.9988])
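To look at this a bit more directly, here is a minimal sketch. It assumes the default momentum=0.1 and track_running_stats=True, and it checks three things. First, the buffers after a single train-mode forward pass match the momentum update described in the nn.BatchNorm2d docs, running = (1 - momentum) * running + momentum * batch_stat, where the variance term uses the unbiased batch variance. Second, backward() and an optimizer step leave the running stats untouched. Third, a forward pass in eval mode does not update them. The loss (out.pow(2).mean()) and the SGD optimizer are arbitrary choices made only for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3)  # assumed defaults: momentum=0.1, track_running_stats=True
opt = torch.optim.SGD(bn.parameters(), lr=0.1)  # only updates weight and bias
x = torch.randn(16, 3, 24, 24)

# Expected per-channel update, computed by hand before the forward pass:
#   running_mean <- (1 - momentum) * running_mean + momentum * batch_mean
#   running_var  <- (1 - momentum) * running_var  + momentum * unbiased batch_var
m = bn.momentum
batch_mean = x.mean(dim=(0, 2, 3))
batch_var = x.var(dim=(0, 2, 3))  # unbiased (n - 1) by default
expected_mean = (1 - m) * bn.running_mean + m * batch_mean
expected_var = (1 - m) * bn.running_var + m * batch_var

bn.train()
out = bn(x)  # the train-mode forward pass updates the buffers right away
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))
print(torch.allclose(bn.running_var, expected_var, atol=1e-5))

# Snapshot the buffers, then run backward + an optimizer step.
mean_after_fwd = bn.running_mean.clone()
var_after_fwd = bn.running_var.clone()
out.pow(2).mean().backward()
opt.step()
# running_mean / running_var are buffers, not parameters, so they are unchanged
print(torch.equal(bn.running_mean, mean_after_fwd))
print(torch.equal(bn.running_var, var_after_fwd))

# In eval mode the stored stats are used for normalization but not updated.
bn.eval()
with torch.no_grad():
    bn(torch.randn(16, 3, 24, 24))
print(torch.equal(bn.running_mean, mean_after_fwd))
print(torch.equal(bn.running_var, var_after_fwd))
```

Because the statistics are registered as buffers rather than parameters, the optimizer never sees them. That is why they are refreshed by the forward pass and not by backward() or step().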