How does one use the mean and std from training in Batch Norm?

I want to use the means and stds accumulated during training rather than the batch statistics, since it seems that if I use batch statistics my model diverges (as outlined here: machine learning - When should one call .eval() and .train() when doing MAML with the PyTorch higher library? - Stack Overflow). How does one do that?
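For reference, here is a minimal sketch of how one can force BN to normalize with the stored running estimates while leaving the rest of the model in train mode (assuming a plain nn.Module called model, not my actual Learner):

    import torch.nn as nn

    # Put only the BatchNorm layers in eval mode: with track_running_stats=True they
    # then normalize with running_mean/running_var instead of the current batch stats.
    model.train()
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eval()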

I am asking since my model's running means are zero, as if no training had been done yet:

Out[1]: BatchNorm2d(32, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
args.base_model.model.features.norm1.running_mean
Out[2]: 
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])

Are these not saved in a checkpoint after training? Should they have been saved?

The docs say they should have been:

Also by default, during training this layer keeps running estimates of its computed mean and variance, which are then used for normalization during evaluation.

but my running means are zero vectors… :confused:


related:

That’s expected, since the running_mean will be initialized with zeros and the running_var with ones.
Both are updated during training, so if no training was performed yet, they will keep their initial values.

If you don’t want to use the running stats but want to normalize the input activations with their batch stats, use track_running_stats=False.
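A small sketch of that option (toy layer, not my model):

    import torch
    import torch.nn as nn

    # With track_running_stats=False no running buffers exist (they are None) and the
    # layer always normalizes with the statistics of the current batch, in train and eval mode.
    bn = nn.BatchNorm2d(32, track_running_stats=False)
    print(bn.running_mean, bn.running_var)  # None None
    x = torch.randn(8, 32, 14, 14)
    y = bn(x)  # normalized with the batch mean/var of x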

Are they zero even from a saved checkpoint?

I’ve been told that someone else’s running_mean is not zero from a saved checkpoint. Could this be an issue caused by alternating between evaluation and training modes in the model? @ptrblck

After reading the BN code in detail, posts around here, and the original paper, my conclusion is here: Inconsistent Batchnorm behavior in eval and training modes - #4 by Brando_Miranda. In summary:

BN intended behaviour:

  • Importantly, during inference (eval/testing) running_mean and running_var are used, because the goal is a deterministic output based on estimates of the population statistics.
  • During training the batch statistics are used, while estimates of the population statistics are accumulated via running averages. I assume batch statistics are used during training to introduce noise that regularizes training (noise robustness). A toy sketch of this behaviour follows the reference below.

ref: [1502.03167] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
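Here is the toy sketch mentioned above, illustrating that behaviour (hypothetical layer, not my model):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(3, momentum=0.1)
    x = torch.randn(16, 3, 8, 8) * 2.0 + 5.0  # data with non-zero mean, non-unit variance

    bn.train()
    _ = bn(x)               # normalizes with the batch statistics of x ...
    print(bn.running_mean)  # ... and moves the running estimates toward x's mean

    bn.eval()
    _ = bn(x)               # now normalizes with running_mean/running_var
    print(bn.running_mean)  # unchanged: no update happens in eval mode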


So the main mystery is to figure out why my models were saved this way, with their running averages from training apparently reset or removed.


meta-learning context: Is there data leakage in the maml-omniglot example? · Issue #107 · facebookresearch/higher · GitHub

The running stats are restored from the checkpoint. If they were updated previously and stored with these updated values, loading the state_dict would restore these updated values.
If they were never updated and the model was used in model.eval() during training, their initial values would be stored and loaded.
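That matches what I would expect; a toy sketch of the save/load round trip (not my checkpoint):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(4)
    bn.train()
    _ = bn(torch.randn(32, 4, 8, 8) + 3.0)     # training-mode forward updates the running stats

    torch.save(bn.state_dict(), 'bn_ckpt.pt')  # running_mean/running_var are buffers, so they are saved

    bn2 = nn.BatchNorm2d(4)
    bn2.load_state_dict(torch.load('bn_ckpt.pt'))
    print(bn2.running_mean)                    # non-zero: the updated values were restored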

@ptrblck do you have any guess how my running means could have ended up set to zero? My understanding is that they should never have that value: in eval mode the running_mean is used, and with track_running_stats=False it should be None. So I am puzzled about where to even start looking in my code for how the running mean could have been set to zero before saving my checkpoint. Do you have an idea how this could happen? Have you ever seen a checkpoint with zeros in the running means?

args.mdl1.model.features.norm1.running_mean
Out[6]: 
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])
args.mdl1.model.features.norm1.running_var
Out[7]: 
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

I don’t think this should ever happen, based on my 3 hours or so of reading the docs and the PyTorch code in some detail.

Doing a search over my code shows that the only places where reset_running_stats appears are in batch_norm and /Users/brando/anaconda3/envs/metalearning/lib/python3.9/site-packages/torch/nn/intrinsic/qat/modules/conv_fused.py (code I didn’t write and am not calling).
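For reference, this is what reset_running_stats() does to the buffers (toy sketch):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(3)
    bn.train()
    _ = bn(torch.randn(16, 3, 8, 8) + 2.0)
    print(bn.running_mean)    # non-zero after one training-mode forward pass
    bn.reset_running_stats()  # re-initializes: running_mean -> zeros, running_var -> ones
    print(bn.running_mean)    # back to zeros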

Also, reset_running_stats is called in the BN layer’s __init__ (torch.nn.modules.batchnorm — PyTorch 1.10.0 documentation)… so it seems my model somehow wasn’t tracking the stats during training, yet when I print the checkpoint track_running_stats is True, which puzzles me even more. A quick check over all BN layers is sketched after the model printout below.

Out[8]: 
Learner(
  (model): ModuleDict(
    (features): Sequential(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm1): BatchNorm2d(32, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
      (relu1): ReLU()
      (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm2): BatchNorm2d(32, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
      (relu2): ReLU()
      (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm3): BatchNorm2d(32, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
      (relu3): ReLU()
      (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (conv4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (norm4): BatchNorm2d(32, eps=0.001, momentum=0.95, affine=True, track_running_stats=True)
      (relu4): ReLU()
      (pool4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (cls): Linear(in_features=800, out_features=5, bias=True)
  )
)
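And here is the quick check I used over all BN layers of the loaded checkpoint (a sketch; the attribute path args.mdl1.model comes from my own setup):

    import torch
    import torch.nn as nn

    model = args.mdl1.model
    for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            at_init = torch.all(m.running_mean == 0) and torch.all(m.running_var == 1)
            print(name, 'still at init values:', bool(at_init))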

They should have these values after they are initialized. I would guess that your training might set the batchnorm layers or the entire model into .eval() mode so that the running stats are never updated and keep their initial values.

Check your code for .eval() calls (and also for self.training = False assignments) and see if that might be the issue.
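To illustrate the failure mode described in this reply, a toy sketch (hypothetical layer, not the actual training code):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(2)
    bn.eval()                                   # e.g. model.eval() accidentally left on during training
    for _ in range(100):
        _ = bn(torch.randn(8, 2, 4, 4) + 3.0)  # forward passes run fine ...
    print(bn.running_mean)                      # ... but the stats stay untouched: zeros, exactly as in my checkpoint
    print(bn.running_var)                       # ones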