Running mean and running stats

I have observed that batch normalization parameters such as the running mean and running var get updated after the forward pass through the model (i.e., right after we do output=model(input)).

So when we train the model and evaluate it after every epoch, is it recommended to put the model in eval mode before running inference, or not?



Yes, when batchnorm is in training mode, the running stats are updated on every forward pass.
So indeed, you should switch to eval mode when evaluating: the layer then uses these stats and stops updating them.
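To make the train/eval difference concrete, here is a minimal pure-Python sketch of a single-feature batch norm layer. `ToyBatchNorm1d` is hypothetical: it only mimics PyTorch's documented defaults (momentum 0.1, eps 1e-5, running mean initialised to 0 and running var to 1, biased variance for normalising and unbiased variance stored in the running var). It is not PyTorch's actual code.

```python
import math


class ToyBatchNorm1d:
    """Hypothetical single-feature batch norm that mimics (not reproduces) PyTorch's."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum = momentum
        self.eps = eps
        self.gamma = 1.0         # learnable scale; changed only by an optimizer step
        self.beta = 0.0          # learnable shift; changed only by an optimizer step
        self.running_mean = 0.0  # buffer; changed by forward passes in training mode
        self.running_var = 1.0   # buffer; changed by forward passes in training mode
        self.training = True

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, batch):
        n = len(batch)
        if self.training:
            if n < 2:
                raise ValueError("need more than 1 value per feature in training mode")
            mean = sum(batch) / n
            var = sum((x - mean) ** 2 for x in batch) / n  # biased: used to normalise
            m = self.momentum
            # The forward pass itself updates the running stats (unbiased variance).
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * var * n / (n - 1)
        else:
            # Eval mode: use the stored stats and leave them untouched.
            mean, var = self.running_mean, self.running_var
        return [self.gamma * (x - mean) / math.sqrt(var + self.eps) + self.beta
                for x in batch]


bn = ToyBatchNorm1d()
bn([1.0, 2.0, 3.0])       # training mode: running stats move
print(bn.running_mean, bn.running_var)

bn.eval()
bn([10.0, 20.0, 30.0])    # eval mode: running stats stay the same
print(bn.running_mean, bn.running_var)
```

In real PyTorch code the equivalent toggle is calling `model.eval()` before the validation loop and `model.train()` before resuming training.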


Thanks for the reply @albanD

So when we put the model in eval mode, does it use the BN params (running_var, running_mean, gamma, and beta) from the training phase?

Yes, once in eval mode, it uses the saved running stats (together with the learned gamma and beta) to compute the output.
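In formula form, writing \hat{\mu} and \hat{\sigma}^2 for the stored running_mean and running_var, \gamma and \beta for the learned affine parameters (weight and bias in PyTorch, assuming the default affine=True), and \epsilon for the small constant added for numerical stability, the eval-mode output for an input x should be:

```latex
y = \gamma \, \frac{x - \hat{\mu}}{\sqrt{\hat{\sigma}^{2} + \epsilon}} + \beta
```

Nothing on the right-hand side depends on the other samples in the batch. Note also that \gamma and \beta are ordinary parameters changed by the optimizer, whereas \hat{\mu} and \hat{\sigma}^2 are buffers changed by forward passes in train mode.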


So if I want to update some other parameters using my evaluation loss (on the validation set), I should first calculate the validation loss with the model in eval mode, so that the running means and vars are not updated, and then switch to train mode to update the additional parameters, right?

It depends on whether you want to update the statistics during these computations or not.

But if we update these statistics during evaluation, we would effectively be leaking information from the validation set into the trained model, which I believe is not recommended.

Oh, if you’re doing regular cross-validation, then indeed you don’t want to do that.
But that might depend on your application.
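To get a feel for how quickly validation data would leak into the stats if you evaluated in train mode, here is a small sketch of the running-mean update, assuming PyTorch's default momentum of 0.1 and the moving-average rule running = (1 - momentum) * running + momentum * batch_mean. The helper `running_mean_after` and the numbers in the example are hypothetical, for illustration only.

```python
def running_mean_after(running_mean, batch_means, momentum=0.1):
    """Apply the train-mode running-mean update once per batch (hypothetical helper)."""
    for batch_mean in batch_means:
        running_mean = (1 - momentum) * running_mean + momentum * batch_mean
    return running_mean


# Suppose training settled the running mean at 0.0, and every validation
# batch has mean 5.0 (made-up numbers for illustration).
leaked = running_mean_after(0.0, [5.0] * 10)  # 10 validation batches in train mode
print(round(leaked, 3))
```

With a constant batch mean, k updates give 0.9^k * r0 + (1 - 0.9^k) * mu_val, so after only 10 batches the running mean has moved about 65% of the way toward the validation mean (5 * (1 - 0.9^10) is roughly 3.26). In eval mode the loop simply never runs and the stored value stays at 0.0.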


@albanD a quick question related to the running mean and var while the model is in train mode. If my model is in train mode, does PyTorch use only the mean and var of the current mini-batch, or does it compute the running mean and var from the previous batches plus the current mini-batch?

You might want to read the doc to get all the details: BatchNorm2d — PyTorch 1.8.0 documentation

But the short version is that during training, it computes the stats on the current batch and uses them both to compute the output and to update the running stats.
During evaluation, it uses the saved running stats to compute the output.
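A quick way to see the practical difference: in train mode a sample's output depends on which other samples share its batch, while in eval mode it does not. The two functions below are hypothetical pure-Python stand-ins for one feature of a batch norm layer, with gamma = 1 and beta = 0 assumed for brevity; they are not PyTorch functions.

```python
import math


def bn_train(batch, eps=1e-5):
    """Train-mode normalisation: uses the current batch's mean and biased variance."""
    mean = sum(batch) / len(batch)
    var = sum((x - mean) ** 2 for x in batch) / len(batch)
    return [(x - mean) / math.sqrt(var + eps) for x in batch]


def bn_eval(batch, running_mean, running_var, eps=1e-5):
    """Eval-mode normalisation: uses stored running stats and ignores the batch."""
    return [(x - running_mean) / math.sqrt(running_var + eps) for x in batch]


# The same input value 2.0 placed in two different batches.
print(bn_train([1.0, 2.0, 3.0])[1], bn_train([2.0, 4.0, 6.0])[0])  # differ
print(bn_eval([1.0, 2.0, 3.0], 0.2, 1.0)[1],
      bn_eval([2.0, 4.0, 6.0], 0.2, 1.0)[0])                      # identical
```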
