The solution is to call mdl.train(). In training mode, the batch norm layer normalizes with batch statistics by itself:
Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation. The running estimates are kept with a default momentum of 0.1. If track_running_stats is set to False, this layer then does not keep running estimates, and batch statistics are instead used during evaluation time as well.
https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html
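
To see the behavior described in that quote, here is a minimal sketch. It uses a standalone nn.BatchNorm2d layer rather than a full model, and the variable names and input shape are made up for illustration, so adapt it to your own mdl.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical input: 8 images, 3 channels, deliberately far from
# zero mean and unit variance so the two modes are easy to tell apart.
x = torch.randn(8, 3, 4, 4) * 5.0 + 10.0

bn = nn.BatchNorm2d(3)  # defaults: momentum=0.1, track_running_stats=True

# Training mode: normalize with the statistics of the current batch.
bn.train()
out_train = bn(x)
train_mean = out_train.mean(dim=(0, 2, 3))  # close to 0 for every channel

# That forward pass also updated the running estimate:
# running_mean = (1 - momentum) * 0 + momentum * batch_mean
batch_mean = x.mean(dim=(0, 2, 3))
running_mean_after = bn.running_mean.clone()

# Evaluation mode: normalize with the running estimates instead.
# After a single update they are still far from the batch statistics.
bn.eval()
out_eval = bn(x)
eval_mean = out_eval.mean(dim=(0, 2, 3))  # clearly not 0

# track_running_stats=False: no running estimates are kept, so batch
# statistics are used even in evaluation mode.
bn_nostats = nn.BatchNorm2d(3, track_running_stats=False)
bn_nostats.eval()
out_nostats = bn_nostats(x)
nostats_mean = out_nostats.mean(dim=(0, 2, 3))  # close to 0 again
```

If the running estimates have not converged yet (for example after only a few training steps), eval mode can give very different outputs from train mode, which is the usual reason switching to mdl.train() appears to fix things.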