I would like to implement batch normalization as originally described in the paper (https://arxiv.org/pdf/1502.03167.pdf). After the network finishes training, the authors make one more pass through the dataset to estimate the population mean and variance (step 10 of Algorithm 2). PyTorch instead uses the common trick of keeping an exponential running average of these statistics during training.
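For reference, step 10 of Algorithm 2 in the linked paper averages the per-mini-batch statistics over training mini-batches \(\mathcal{B}\) of size \(m\). The \(m/(m-1)\) factor turns the averaged (biased) mini-batch variance into an unbiased estimate. At inference, each activation is then normalized with these fixed population estimates instead of batch statistics:

```latex
\mathrm{E}[x] \leftarrow \mathrm{E}_{\mathcal{B}}[\mu_{\mathcal{B}}], \qquad
\mathrm{Var}[x] \leftarrow \frac{m}{m-1}\,\mathrm{E}_{\mathcal{B}}[\sigma_{\mathcal{B}}^2]

y = \gamma \, \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta
```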
Is there an easy way to implement this? Basically, I need access to the mini-batch mean and variance computed within each batch normalization layer.
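One possible approach, sketched here under stated assumptions and not necessarily what was suggested in this thread: PyTorch's BatchNorm layers accept `momentum=None`, which replaces the exponential running average with a cumulative (equal-weight) average of the per-batch statistics. Resetting the running statistics after training and making one extra pass over the training data in that mode yields the average of the mini-batch means and of the unbiased mini-batch variances. This matches step 10 when all batches have the same size. The helper name `estimate_population_stats` is hypothetical. The sketch assumes the default `track_running_stats=True` and a loader whose items put the inputs first.

```python
import torch
import torch.nn as nn

# Public BatchNorm classes; others (e.g. SyncBatchNorm) could be added if needed.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def estimate_population_stats(model: nn.Module, loader, device="cpu") -> nn.Module:
    """Hypothetical helper: re-estimate BatchNorm statistics with one extra pass.

    Assumes every BatchNorm layer was built with track_running_stats=True
    (the default) and that each item from `loader` has the inputs first.
    """
    bn_layers = [m for m in model.modules() if isinstance(m, BN_TYPES)]
    saved_momenta = [bn.momentum for bn in bn_layers]
    was_training = model.training

    model.eval()  # keep non-BN layers such as dropout deterministic during the pass
    for bn in bn_layers:
        bn.reset_running_stats()  # running_mean=0, running_var=1, num_batches_tracked=0
        bn.momentum = None        # cumulative average instead of exponential average
        bn.train()                # normalize with batch stats and update running stats

    with torch.no_grad():  # statistics only: no gradients, weights are not changed
        for batch in loader:
            inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
            model(inputs.to(device))

    # Restore the original momentum values and train/eval mode.
    for bn, momentum in zip(bn_layers, saved_momenta):
        bn.momentum = momentum
    model.train(was_training)
    return model
```

After calling `estimate_population_stats(model, train_loader)`, switching to `model.eval()` makes every BatchNorm layer normalize with the re-estimated statistics. Keeping the other layers in eval mode during the pass is a design choice; the BatchNorm layers themselves must be in train mode so that later layers see batch-normalized activations, as in the training network. If the last batch is smaller, passing `drop_last=True` to the `DataLoader` keeps every batch the same size, which step 10 assumes.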
Thank you, SimonW! I think that will work. I'm coding it up now.
Re Tom: setting track_running_stats=False makes the layers use mini-batch statistics at test time (just as they do at training time). However, that is different from estimating the population statistics and using them at test time. Using mini-batch statistics at test time has several undesirable properties; for example, the prediction for an image depends on the other images in its mini-batch.
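A small demonstration of that last point, as a sketch with made-up random data: with `track_running_stats=False`, a `BatchNorm1d` layer in eval mode still normalizes with batch statistics, so the output for the same input changes when its batch companions change. A layer with tracked statistics in eval mode normalizes each sample independently. The helper `first_output` is hypothetical, and the freshly built layer's default running statistics (mean 0, variance 1) merely stand in for real population estimates.

```python
import torch
import torch.nn as nn


def first_output(bn: nn.Module, x: torch.Tensor, others: torch.Tensor) -> torch.Tensor:
    """Normalize x together with `others` and return only x's output row."""
    with torch.no_grad():
        return bn(torch.cat([x, others]))[0]


torch.manual_seed(0)
x = torch.randn(1, 3)         # the single "test image" (3 features)
others_a = torch.randn(7, 3)  # two different sets of companion samples
others_b = torch.randn(7, 3)

# track_running_stats=False: batch statistics are used even in eval mode.
bn_batch = nn.BatchNorm1d(3, track_running_stats=False).eval()
batch_a = first_output(bn_batch, x, others_a)
batch_b = first_output(bn_batch, x, others_b)
print(torch.allclose(batch_a, batch_b))  # False: x's output depends on its batch

# Tracked statistics in eval mode: each sample is normalized on its own.
# A fresh layer has running_mean=0 and running_var=1; after training or a
# population pass these buffers would hold the estimated statistics.
bn_pop = nn.BatchNorm1d(3).eval()
pop_a = first_output(bn_pop, x, others_a)
pop_b = first_output(bn_pop, x, others_b)
print(torch.allclose(pop_a, pop_b))  # True
```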