Batchnorm: how to stop updating the running stats but still use the stats of a batch?

Hi,

I have a question about controlling the batchnorm.

As far as I know, in eval() mode
the batchnorm layer stops updating the running stats and instead uses the running stats that are registered as buffers.

However, I would like to stop updating the running stats while still normalizing with the stats of the current batch in the forward pass.

Is there a method for this?

Thanks.

If you don’t want to use the running stats at all, use track_running_stats=False during the initialization of the batchnorm layer and the input stats will always be used.
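A minimal sketch of that behavior (the layer size and the random input are arbitrary choices for illustration): with track_running_stats=False no running_mean/running_var buffers exist, so even after eval() the output is normalized with the statistics of the incoming batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# No running_mean / running_var buffers are created for this layer.
bn = nn.BatchNorm1d(4, track_running_stats=False)
print(bn.running_mean, bn.running_var)  # None None

x = torch.randn(8, 4) * 3 + 5  # batch whose mean/std are far from 0/1

bn.eval()  # even in eval mode the layer normalizes with the batch stats
out = bn(x)
print(out.mean(dim=0))                             # ~0 per feature
print(((out - out.mean(dim=0)) ** 2).mean(dim=0))  # ~1 per feature (biased variance)
```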

Thank you for your quick response!

I am fine-tuning a pretrained model that already has tracked running stats,
and I want to make sure that those running stats do not change during fine-tuning.

I don’t know whether you still want to use the running stats later. If you do, call model.bn.train() so the layer uses the input activation stats, and reload the state_dict of these layers afterwards (or manually reset the running stats). If you don’t, you should be able to replace the pretrained batchnorm layers with new ones created with track_running_stats=False and load only the affine parameters (if the pretrained model uses them).
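Both options can be sketched roughly as follows, assuming plain nn.BatchNorm1d/2d/3d layers (SyncBatchNorm etc. are not handled). The helper names save_bn_stats, restore_bn_stats and to_batch_stats_bn, as well as the small Sequential model and the training loop, are hypothetical stand-ins for the real pretrained model. Option 1 copies back only the running-stat buffers, because reloading a layer's full state_dict would also revert its fine-tuned weight and bias.

```python
import copy
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)

# Option 1 (hypothetical helpers): train with batch stats, then restore the saved running stats.
def save_bn_stats(model):
    """Clone running_mean / running_var / num_batches_tracked of every BN layer."""
    return {
        name: {k: v.detach().clone() for k, v in m.named_buffers(recurse=False)}
        for name, m in model.named_modules()
        if isinstance(m, BN_TYPES)
    }

@torch.no_grad()
def restore_bn_stats(model, saved):
    """Copy the saved buffers back in place; affine weight/bias are left untouched."""
    modules = dict(model.named_modules())
    for name, buffers in saved.items():
        for k, v in buffers.items():
            getattr(modules[name], k).copy_(v)

# Option 2 (hypothetical helper): swap BN layers for ones that never track running stats.
def to_batch_stats_bn(module):
    """Recursively replace BN layers with track_running_stats=False copies,
    carrying over the pretrained affine parameters if the layer has them."""
    for name, child in module.named_children():
        if isinstance(child, BN_TYPES):
            new_bn = type(child)(child.num_features, eps=child.eps,
                                 momentum=child.momentum, affine=child.affine,
                                 track_running_stats=False)
            if child.affine:
                new_bn = new_bn.to(child.weight.device)
                with torch.no_grad():
                    new_bn.weight.copy_(child.weight)
                    new_bn.bias.copy_(child.bias)
            setattr(module, name, new_bn)
        else:
            to_batch_stats_bn(child)
    return module

# Hypothetical stand-in for the pretrained model.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)

saved = save_bn_stats(model)
model.train()  # BN normalizes with batch stats here, but also updates its buffers
for _ in range(5):
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
restore_bn_stats(model, saved)  # running stats are back to their saved values

# Option 2 applied to a copy of the model; the BN layer now has no running stats.
stateless = to_batch_stats_bn(copy.deepcopy(model))
```

With option 2 the swapped-in layers hold new Parameter objects, so the optimizer should be created after the replacement; otherwise it keeps updating the old parameters.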


Thanks for your kind explanation.
Following your suggestion, I will manually handle the batchnorm parameters or running stats.

Thank you!