Batch Normalization Momentum Meaning

Batch normalization has a momentum parameter in PyTorch, but I’m not entirely sure what it does. I have seen batch normalization with a decay parameter before. Is momentum doing the same thing?

In this third-party TensorFlow tutorial, batch norm's moving averages are updated using a (1 - decay) factor:

# TF1-style API (tf.assign), as shown in the tutorial
decay = 0.999  # use numbers closer to 1 if you have more data
train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))

Is PyTorch's momentum equivalent to the (1 - decay) term above?


Yes, PyTorch's momentum is exactly the (1 - decay) factor: the running statistic is updated as running = (1 - momentum) * running + momentum * batch_stat. Also, this naming is a bug, but since fixing it would be a breaking change, we will leave it as is. I do have a PR pending to add a warning.
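
As a quick sanity check, here is a small sketch (assuming a recent PyTorch, where running_mean starts at zeros) that replicates the running-mean update by hand and compares it to what BatchNorm1d actually stores after one training step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# momentum = 0.001 plays the role of decay = 0.999 in the TF snippet:
# running = (1 - momentum) * running + momentum * batch_stat
momentum = 0.001

bn = nn.BatchNorm1d(3, momentum=momentum)
x = torch.randn(32, 3) * 2.0 + 5.0

bn.train()
bn(x)  # one forward pass in training mode updates the running stats

# Replicate the running-mean update manually (running_mean starts at zeros)
batch_mean = x.mean(dim=0)
expected = (1 - momentum) * torch.zeros(3) + momentum * batch_mean

print(torch.allclose(bn.running_mean, expected))  # True
```

So momentum=0.001 in PyTorch corresponds to decay=0.999 in the tutorial's formulation; note that PyTorch's default of momentum=0.1 is a much faster-moving average than the tutorial's decay=0.999.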