Hi, I spotted an unintended behaviour, or possibly a bug, caused by this fix:
[BatchNorm] Unexpected behaviour with track_running_stats · Issue #37823 · pytorch/pytorch (github.com)
but before opening an issue I would like to make sure I am understanding this correctly.
In the past (before that commit), we could change the behaviour of BatchNorm after the module was created. That is, we could run it either in train mode (stats computed on the batch) or in eval mode (using `running_*`), depending on the value of `.track_running_stats`, which could be changed even after module creation with:

```python
for m in self.net.modules():
    if isinstance(m, torch.nn.BatchNorm2d):
        m.track_running_stats = False
```

Indeed, I have used this trick in the past to always compute batch stats, even in eval mode, and it worked.
The newer commit doesn't allow this behaviour. That's because it will put the batch norm in "train" mode only if `.running_*` are set to None, which only happens when `track_running_stats=False` is passed when the module is first initialized, and not when the flag is flipped later on, as in the snippet above.
The line that establishes whether it's train or eval mode is this in the earlier commit:

`self.training or not self.track_running_stats`

and this in the newer commit:

`(self.running_mean is None) and (self.running_var is None)`
As you can see, with the newer commit, setting `track_running_stats = False` after initialization won't put the module in train mode (since `running_*` are not None); instead it will run in eval mode using the default `running_mean` (0) and `running_var` (1), creating disastrous results.
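To illustrate the point (a minimal sketch against a recent PyTorch; the tensor shapes are made up for the example): flipping the flag after construction leaves the running buffers in place, so `eval()` still normalizes with the default buffers instead of the batch statistics.

```python
import torch

bn = torch.nn.BatchNorm2d(3)       # track_running_stats=True by default
bn.track_running_stats = False     # flipped *after* construction
bn.eval()

# The running buffers still exist, so eval() keeps using them
# (running_mean=0, running_var=1 -> output is roughly the input).
print(bn.running_mean is None)     # False

x = torch.randn(2, 3, 4, 4) * 5 + 10   # stats far from N(0, 1)
y = bn(x)
print(y.mean().item())             # stays near 10: batch stats were NOT used
```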
Quick fix: pass `track_running_stats=False` when the module is initialized, not later.
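For completeness, a sketch of the quick fix (same made-up shapes as before): passing the flag at construction leaves the running buffers as None, so batch statistics are used even in eval mode.

```python
import torch

bn = torch.nn.BatchNorm2d(3, track_running_stats=False)  # set at init
bn.eval()

print(bn.running_mean is None)     # True -> batch stats even in eval

x = torch.randn(2, 3, 4, 4) * 5 + 10
y = bn(x)
print(y.mean().item())             # ~0: normalized with batch stats
```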
Is this worth fixing?
ps: I am using PyTorch 1.7.0