Short version: Are the Batch Norm running mean and variance included when saving a model with torch.save()?
Long version: I recently discovered an issue where my model parameters kept growing during training. This was not a problem for training itself, but it actually gave a lower score under model.eval(), because the batch norm running mean and variance were never quite in sync with the newest model parameters.
Now my thought is: if those “delayed” running mean/var get saved when I use torch.save(), then from my understanding I will still have the same issue when I load the model and do inference with model.eval().
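For anyone wondering about the short version: as far as I understand, BatchNorm registers running_mean and running_var as buffers, so they show up in state_dict() and are serialized by torch.save() together with the learned weights. A minimal check (the Sequential model here is just an illustration, not from my actual setup):

```python
import torch
import torch.nn as nn

# BatchNorm's running statistics are buffers, so they appear in
# state_dict() and get saved by torch.save(model.state_dict(), ...)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
keys = list(model.state_dict().keys())
print(keys)
# The BatchNorm entry ('1') contributes 'running_mean', 'running_var'
# and 'num_batches_tracked' in addition to its affine weight and bias.
```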
This could be resolved by training with a learning rate of 0 for a while, so that the BN statistics catch up, and then saving the model, but that seems very hacky. It’s a long shot, but if anyone has encountered a similar issue I would be glad for additional info, because I am genuinely puzzled about what the “proper” way would be here.
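The zero-learning-rate trick above amounts to doing forward passes that only update the BN buffers. A sketch of doing that directly, without touching the optimizer at all (model, loader, and refresh_bn_stats are placeholder names, not part of any PyTorch API):

```python
import torch
import torch.nn as nn

def refresh_bn_stats(model, loader, num_batches=100):
    """Re-estimate BatchNorm running_mean/running_var on recent data.

    Equivalent in spirit to training with lr=0: the parameters are
    untouched (no_grad, no optimizer step), but forward passes in
    train() mode still update the BN running statistics.
    """
    model.train()  # BN only updates its buffers in train mode
    with torch.no_grad():  # gradients are not needed for buffer updates
        for i, (x, _) in enumerate(loader):
            if i >= num_batches:
                break
            model(x)
    model.eval()  # back to inference mode with the refreshed stats
```

After calling this right before torch.save(), the saved running mean/var should reflect the final parameters rather than lagging behind them.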