Batch Norm Running Mean and Variance Included in torch.save?

Hi,

Short version: Are the batch norm running mean and variance included when using torch.save()?

Long version: I recently ran into an issue where my model's parameters kept growing during training. This did not hurt training itself, but it gave a lower score with model.eval(), because the batch norm running mean and variance were never quite in sync with the newest model parameters.

My concern is about saving with torch.save() and then loading the model for inference: if those “delayed” running mean/var values get saved, then as I understand it I will have the same issue when I run inference after torch.load().

This could be resolved by training with a learning rate of 0 for a while so that the BN values catch up, and then saving the model, but this seems very hacky. It’s a long shot, but if anyone has encountered a similar issue I would be glad for additional info, because I am really puzzled about what the “proper” way would be here.
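
To make the workaround concrete, here is a rough sketch of what I mean. It assumes that forward passes in train mode under torch.no_grad() have the same effect on the BN statistics as training with a learning rate of 0. The helper name refresh_bn_stats, the toy model and the random batches are made up for illustration and are not my real setup.

```python
import torch
import torch.nn as nn


def refresh_bn_stats(model, batches):
    """Update BatchNorm running statistics without changing any weights."""
    was_training = model.training
    model.train()  # BatchNorm only updates its running stats in train mode
    with torch.no_grad():  # no gradients and no optimizer step: weights stay fixed
        for x in batches:
            model(x)
    model.train(was_training)  # restore the previous train/eval mode


# Hypothetical toy model and data, for illustration only.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
weights_before = model[0].weight.clone()
mean_before = model[1].running_mean.clone()

batches = [torch.randn(32, 4) + 3.0 for _ in range(10)]
refresh_bn_stats(model, batches)

weights_unchanged = torch.equal(model[0].weight, weights_before)
stats_moved = not torch.equal(model[1].running_mean, mean_before)
print(weights_unchanged, stats_moved)
```

With the default momentum the running statistics are an exponential moving average, so how many batches are needed for them to settle is something I would have to check empirically.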

They are included in the state_dict.
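
A quick way to check this yourself, as a minimal sketch: the toy model below is hypothetical and only serves to show which keys appear in state_dict() and that the running statistics survive a torch.save / torch.load round trip. It saves to an in-memory buffer instead of a file to stay self-contained.

```python
import io

import torch
import torch.nn as nn

# Hypothetical toy model, used only to inspect what state_dict() contains.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

# One forward pass in train mode so the running statistics move off their defaults.
model.train()
model(torch.randn(16, 4))

state = model.state_dict()
bn_keys = [k for k in state if k.startswith("1.")]
print(bn_keys)  # lists the entries stored for the BatchNorm layer

# Round trip through torch.save / torch.load using an in-memory buffer.
buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

restored = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
restored.load_state_dict(torch.load(buffer))

same_mean = torch.equal(restored[1].running_mean, model[1].running_mean)
same_var = torch.equal(restored[1].running_var, model[1].running_var)
print(same_mean, same_var)
```

So whatever running mean/var the model has at the moment you save it is exactly what you get back after loading.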
