BatchNorm2d bug/unintended behaviour

Hi, I spotted an unintended behaviour, or possibly a bug, caused by this fix:
[BatchNorm] Unexpected behaviour with track_running_stats · Issue #37823 · pytorch/pytorch

But before opening an issue, I would like to make sure I understand this correctly.

In the past (before that commit), we could change the behaviour of BatchNorm after the module was created. That is, we could run it either in train mode (statistics computed on the batch) or in eval mode (using running_mean/running_var), depending on the value of .track_running_stats, which could be changed even after module creation with:

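import torch

# `model` stands for the network that contains the BatchNorm2d layers
for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.track_running_stats = False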

Indeed, I have used this trick in the past to always compute batch statistics, even in eval mode, and it worked.

The newer commit doesn’t allow this. In eval mode it only uses batch statistics if .running_mean and .running_var are None, which only happens when track_running_stats=False is passed when the BatchNorm module is first created, not when the attribute is changed later, as in the snippet above.

The line that decides between train and eval mode is this one (in the earlier commit):

    or not self.track_running_stats,

and this one (in the newer commit):

    (self.running_mean is None) and (self.running_var is None)
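A minimal pure-Python sketch to compare the two checks side by side. It is a simplification, not the actual PyTorch code: the function names are hypothetical, plain lists stand in for the buffer tensors, and it assumes that the full earlier expression was `self.training or not self.track_running_stats` and that the newer check is only consulted when the module is not training.

```python
from typing import Optional, Sequence


def uses_batch_stats_old(training: bool, track_running_stats: bool) -> bool:
    # Earlier check (assumed full expression): the flag alone decides,
    # so flipping it after construction takes effect immediately.
    return training or not track_running_stats


def uses_batch_stats_new(training: bool,
                         running_mean: Optional[Sequence[float]],
                         running_var: Optional[Sequence[float]]) -> bool:
    # Newer check: in eval mode only the presence of the buffers matters;
    # the track_running_stats flag is not consulted at all.
    if training:
        return True
    return running_mean is None and running_var is None


# Layer built with the default track_running_stats=True, flag flipped to False
# later, then put in eval mode: the buffers still exist.
print(uses_batch_stats_old(training=False, track_running_stats=False))
print(uses_batch_stats_new(training=False, running_mean=[0.0], running_var=[1.0]))
```

The first call returns True (batch statistics) and the second returns False (running statistics), which is exactly the change in behaviour described here.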

As you can see, with the newer commit, if track_running_stats is set to False after initialization, the module does not switch to batch statistics (running_* are not None). Instead it stays in eval mode and normalizes with the stored running_mean and running_var, which are still at their defaults of 0 and 1, with disastrous results.
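A short reproduction sketch, assuming a PyTorch version that contains the newer check (such as the 1.7.0 mentioned below). It flips the flag on a freshly created nn.BatchNorm2d, so the buffers still hold their initial values; the input shape and offsets are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(8)          # default track_running_stats=True: buffers are created
bn.track_running_stats = False  # flipped after construction, as in the loop above
bn.eval()

# Input far from zero mean / unit variance.
x = torch.randn(4, 8, 5, 5) * 3 + 10
with torch.no_grad():
    y = bn(x)

# The buffers are still at their initial values (all zeros and all ones) ...
print(bn.running_mean.abs().max().item(), bn.running_var.min().item())
# ... so the output is just x / sqrt(1 + eps), i.e. not normalized at all.
print(torch.allclose(y, x / (1 + bn.eps) ** 0.5, rtol=1e-5, atol=1e-5))
```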

Quick fix: set track_running_stats=False when the module is initialized, not later.
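A sketch of the quick fix, with the same arbitrary shapes: passing track_running_stats=False to the constructor means the running buffers are registered as None, so the layer normalizes with batch statistics even after .eval().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Flag set at construction time: running_mean / running_var are None.
bn = nn.BatchNorm2d(8, track_running_stats=False)
print(bn.running_mean is None, bn.running_var is None)

bn.eval()  # batch statistics are still used in eval mode
x = torch.randn(4, 8, 5, 5) * 3 + 10
with torch.no_grad():
    y = bn(x)

# Per-channel mean of the output is ~0, i.e. the batch itself was normalized.
print(y.mean(dim=(0, 2, 3)).abs().max().item())
```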

Is this worth fixing?

PS: I am using PyTorch 1.7.0.

Could you explain a bit more what exactly should be fixed?

This was indeed the case, but you could also get the same behavior by calling .train() and .eval() on the BatchNorm layers.
I guess your use case is a bit special, since it relied on the “hack” of changing the track_running_stats attribute?
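A minimal sketch of the .train()/.eval() approach: put the whole model in eval mode, then switch only the BatchNorm layers back to train mode so they normalize with batch statistics. The helper name eval_with_batch_stats and the small example model are hypothetical. Note that in train mode, layers created with track_running_stats=True also update running_mean and running_var on every forward pass, so save those buffers first if they must stay untouched.

```python
import torch.nn as nn

# BatchNorm variants to switch back to train mode.
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def eval_with_batch_stats(model: nn.Module) -> nn.Module:
    # Hypothetical helper: everything (dropout etc.) goes to eval mode ...
    model.eval()
    for m in model.modules():
        if isinstance(m, BN_TYPES):
            # ... except BatchNorm, which then uses batch statistics.
            m.train()
    return model


model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.BatchNorm2d(4), nn.Dropout(0.5))
eval_with_batch_stats(model)
print([m.training for m in model])  # [False, True, False]
```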