Population statistics for batchnorm instead of running average

I would like to implement batch normalization as originally described in the paper (https://arxiv.org/pdf/1502.03167.pdf). After the network finishes training, they make one more pass through the dataset to estimate the population mean and variance (step 10 in Algorithm 2). PyTorch instead uses the common trick of keeping a running average of these statistics during training.

Is there an easy way to implement this? Basically I need to have access to the mini batch mean and variance computed within each batch normalization layer.

Is there code for this already?


You can

  1. add a forward_pre_hook to each BN layer that tracks the batch means and variances
  2. iterate through the data
  3. replace `running_mean` and `running_var` with your aggregated estimates.
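A minimal sketch of those three steps might look like the following. The function name `estimate_population_stats` is made up for illustration, and averaging the per-batch statistics is only exact when all batches have the same size (drop the last partial batch, or weight by batch size, if they don't):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def estimate_population_stats(model, data_loader, device="cpu"):
    """Replace each BN layer's running stats with population estimates
    from one extra pass over the data (Algorithm 2, step 10)."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]

    # Per-layer accumulators for batch means and (unbiased) batch variances.
    mean_sum = {m: torch.zeros_like(m.running_mean) for m in bn_layers}
    var_sum = {m: torch.zeros_like(m.running_var) for m in bn_layers}
    n_batches = {m: 0 for m in bn_layers}

    def make_hook(module):
        def hook(mod, inp):
            x = inp[0]
            # Reduce over every dimension except the channel dim (dim 1).
            dims = [0] + list(range(2, x.dim()))
            mean_sum[mod] += x.mean(dim=dims)
            var_sum[mod] += x.var(dim=dims, unbiased=True)
            n_batches[mod] += 1
        return hook

    hooks = [m.register_forward_pre_hook(make_hook(m)) for m in bn_layers]

    model.train()  # BN must normalize with batch stats during this pass
    for batch, _ in data_loader:
        model(batch.to(device))

    for h in hooks:
        h.remove()

    # Overwrite the running statistics with the aggregated estimates.
    for m in bn_layers:
        m.running_mean.copy_(mean_sum[m] / n_batches[m])
        m.running_var.copy_(var_sum[m] / n_batches[m])
    model.eval()
```

After calling this, `model.eval()` inference uses the population estimates instead of the exponential running averages.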

You might check whether `track_running_stats` in `BatchNormXd` (torch 0.4 and above) does what you want.
You can probably copy-paste the relevant code if you are on an older version.

Best regards


Thank you SimonW! I think that will work. Just coding it up now.

Re Tom: setting `track_running_stats=False` will use the mini-batch statistics at test time (just as at training time). However, this is different from estimating the population statistics and using those at test time. Using mini-batch statistics at test time has many undesirable properties, including that the prediction for an image depends on the other images in its mini-batch.
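That batch-dependence is easy to demonstrate. A small sketch: with `track_running_stats=False` there are no running statistics, so even in eval mode BN normalizes with the current batch's statistics, and the same sample gets a different output depending on its batchmates:

```python
import torch
import torch.nn as nn

# No running stats are kept, so eval mode falls back to batch statistics.
bn = nn.BatchNorm1d(2, track_running_stats=False)
bn.eval()

x = torch.tensor([[1.0, 2.0]])  # the sample we care about

# Same sample, two different batchmates.
out_a = bn(torch.cat([x, torch.zeros(1, 2)]))[0]
out_b = bn(torch.cat([x, torch.full((1, 2), 10.0)]))[0]

# out_a != out_b: the prediction for x depends on the rest of the batch.
```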

If someone still needs this, we wrote up a small script to do this: