Use running averages of batch normalization during training as well?

In this video, running batch normalization is discussed as an alternative to regular batch normalization, to eliminate the training–inference disparity and improve model performance. This works simply by using the running averages, not only during inference, but during training as well. Is there some way to do this using the BatchNorm1d and BatchNorm2d layers in PyTorch, or do I need to roll my own module to do that?
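For reference, the training/inference disparity with the stock layer can be seen in a few lines. This is a small sketch using `nn.BatchNorm1d` with its default settings (momentum 0.1, running mean initialized to 0 and running variance to 1); the input batch is made up for illustration. In training mode the layer normalizes with the current batch's statistics, and in eval mode it uses the running averages, so the same input produces different outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(3)           # default momentum=0.1, affine weight=1 / bias=0
x = torch.randn(16, 3) * 2 + 5   # made-up batch whose stats differ from the initial running stats

bn.train()
y_train = bn(x)  # normalized with this batch's mean/var; running stats are updated as a side effect

bn.eval()
y_eval = bn(x)   # same input, now normalized with the running averages

# The two modes disagree on the same input.
print(torch.allclose(y_train, y_eval))
```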

If I understand your use case correctly, you would like to:

  • calculate the batch stats
  • update the running stats with the current batch stats using the running average
  • normalize the input using the updated running stats instead of the batch stats

If so, then I don’t think you can use the nn.BatchNormXd layers to do so and would need to implement a custom layer.
You could use my manual implementation as a template and change the logic in the forward method.
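A minimal sketch of such a custom layer, assuming 2D input of shape `(N, C)` (the `BatchNorm1d` case). The class name `RunningBatchNorm1d` and its constructor arguments are hypothetical, not a PyTorch API, and this is not the manual implementation mentioned above; it just follows the three steps listed: compute the batch stats, fold them into the running stats with the same momentum convention as `nn.BatchNormXd`, then normalize with the updated running stats.

```python
import torch
import torch.nn as nn


class RunningBatchNorm1d(nn.Module):
    """Hypothetical sketch: normalize with the updated running stats in training too.

    Assumes input of shape (N, C). Gradients flow through the current batch's
    contribution to the running mean/var; the previous running values act as constants.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        # Learnable affine parameters, initialized like nn.BatchNorm1d.
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # Running statistics are buffers: saved in state_dict, not trained by the optimizer.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # 1) Batch statistics over the batch dimension (biased variance for simplicity).
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            # 2) Exponential moving average with the nn.BatchNorm momentum convention:
            #    new = (1 - momentum) * old + momentum * batch
            mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
            var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
            # Store the updated values without keeping the autograd graph in the buffers.
            with torch.no_grad():
                self.running_mean.copy_(mean)
                self.running_var.copy_(var)
        else:
            mean, var = self.running_mean, self.running_var
        # 3) Normalize with the (updated) running statistics in both modes.
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return x_hat * self.weight + self.bias


# Usage: the module starts in training mode, like any nn.Module.
layer = RunningBatchNorm1d(16)
out = layer(torch.randn(32, 16))
```

A few design choices worth noting: the biased variance is used for both the update and the normalization, whereas `nn.BatchNormXd` uses the unbiased estimate when updating `running_var`. Early in training the running stats are dominated by their initial values (0 and 1), so some variants add a bias-correction step; that is not shown here. For the `BatchNorm2d` case you would reduce over dims `(0, 2, 3)` and reshape the statistics and affine parameters to `(1, C, 1, 1)` before normalizing.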
