Use running_mean instead of batch mean when calculating variance in BatchNorm2d at train time

Is there an easy way to make BatchNorm2d at train time calculate the variance of a batch with respect to the running_mean instead of the mean of that batch? That is, to use variance = mean((x - running_mean)^2) instead of mean((x - mean(x))^2).

I need this feature for batch_size > 1. However, I expect it to be generally useful, since it would also allow variance estimates in BatchNorm when training with batch_size = 1.

Many thanks in advance!

You could take a look at my manual batchnorm implementation and adapt it as you wish.
I would also recommend trying to script your module afterwards, so that e.g. nvFuser can code-generate the (fused) kernels and recover some of the speed.
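For reference, here is a minimal sketch of such a manual BatchNorm2d variant. It is an illustration, not the linked implementation: the class name `RunningMeanVarBatchNorm2d` and the choice to still center with the batch mean (only the variance is taken around `running_mean`) are assumptions on my part.

```python
import torch
import torch.nn as nn


class RunningMeanVarBatchNorm2d(nn.Module):
    """Hypothetical BatchNorm2d variant: at train time the batch variance
    is computed around the running mean instead of the batch mean."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            batch_mean = x.mean(dim=(0, 2, 3))
            # Variance around the *running* mean, not the batch mean
            rm = self.running_mean.view(1, -1, 1, 1)
            var = ((x - rm) ** 2).mean(dim=(0, 2, 3))
            with torch.no_grad():
                # Exponential moving average update of the running stats
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
            # Assumption: still center with the batch mean, as standard BatchNorm does
            mean = batch_mean
        else:
            mean = self.running_mean
            var = self.running_var
        x_hat = (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + self.eps)
        return x_hat * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
```

Note that this plain-Python version will be slower than the fused `nn.BatchNorm2d` kernel, which is why scripting/fusing it afterwards is worthwhile.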

Thanks, Piotr. That was exactly what I was looking for!
