In this video, running batch normalization is discussed as an alternative to regular batch normalization: it eliminates the training–inference disparity and can improve model performance. The idea is simply to normalize with the running averages not only during inference but during training as well. Is there a way to do this with the `BatchNorm1d` and `BatchNorm2d` layers in PyTorch, or do I need to roll my own module (see the sketch below for what I have in mind)?
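If rolling my own is the answer, here is a minimal sketch of what I have in mind for the 1-D case, assuming input of shape (N, C); the class name `RunningBatchNorm1d` and all of its details are my own guesses, not anything from the video or from PyTorch itself:

```python
import torch
import torch.nn as nn

class RunningBatchNorm1d(nn.Module):
    """Hypothetical sketch: normalize with the running statistics in BOTH
    training and inference, so train and eval behave identically.
    This is not a PyTorch API, just my guess at an implementation."""

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super().__init__()
        self.momentum = momentum
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))   # learnable gamma
        self.bias = nn.Parameter(torch.zeros(num_features))    # learnable beta
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):  # x: (N, C)
        if self.training:
            # Update the running statistics from the current batch,
            # but do NOT normalize with the batch statistics themselves.
            with torch.no_grad():
                batch_mean = x.mean(dim=0)
                batch_var = x.var(dim=0, unbiased=False)
                # EMA update: running <- (1 - momentum) * running + momentum * batch
                self.running_mean.lerp_(batch_mean, self.momentum)
                self.running_var.lerp_(batch_var, self.momentum)
        # Same normalization path in train and eval mode.
        x_hat = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        return self.weight * x_hat + self.bias
```

The closest built-in behavior I can find is calling `.eval()` on a `BatchNorm1d` layer during training, which does make it normalize with the running statistics, but as far as I can tell it then stops updating them, which is why I'm asking whether there's a supported way to get both.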