Proper way to call mdl.eval() and mdl.train() so that checkpointing saves all parameters?


I usually alternate between training and testing in my code. This results in me calling mdl.train() when coming back to the training loop and mdl.eval() when coming to the testing loop. However, I've noticed that this potentially creates issues with how parameters that depend on the train/eval state are saved (e.g. batch norm, perhaps others). This raises the question of how these flags should be called, if at all. If I call mdl.eval() before checkpointing, I believe this will delete my running average as mentioned here: How does one use the mean and std from training in Batch Norm?.

I think a deep copy of the model during evaluation should fix things - but that seems like it will kill my GPU memory (and perhaps my normal memory too).

Thus, what is the proper way to alternate between evaluation and training in PyTorch so that checkpoints are saved properly (e.g. the running stats of training are NOT deleted)?

This is a challenge because I always run evaluation to see if the current model is better on validation than the previous one and then decide based on THAT whether to save it - which requires me to do an evaluation before checkpointing - always.

Perhaps I can always run it in train mode so that the running averages are saved correctly. Since the actual validation value doesn't matter, it's fine if the stats from train and val leak into the validation.


Yes, calling mdl.train() when entering the training loop and mdl.eval() when entering the testing loop is the proper way.
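As a rough sketch of that alternation (not the original poster's code), the loop below switches to train mode for each update, switches to eval mode for validation, and only then decides whether to checkpoint. The toy model, the random data, the `evaluate` helper and the `best_mdl.pt` file name are all assumptions made up for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn


def evaluate(mdl, x, y, loss_fn):
    """Hypothetical helper: validation loss in eval mode, without gradients."""
    mdl.eval()  # batch norm now uses its running stats instead of batch stats
    with torch.no_grad():
        return loss_fn(mdl(x), y).item()


torch.manual_seed(0)
# Toy model and random data, assumed purely for illustration.
mdl = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
opt = torch.optim.SGD(mdl.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(64, 4), torch.randn(64, 1)
x_val, y_val = torch.randn(32, 4), torch.randn(32, 1)

# Hypothetical checkpoint location.
ckpt_path = os.path.join(tempfile.gettempdir(), "best_mdl.pt")
best_val = float("inf")

for epoch in range(3):
    mdl.train()  # back to train mode: running stats are updated again
    opt.zero_grad()
    loss_fn(mdl(x_train), y_train).backward()
    opt.step()

    # Evaluate first, then decide whether to checkpoint.
    val_loss = evaluate(mdl, x_val, y_val, loss_fn)
    if val_loss < best_val:
        best_val = val_loss
        torch.save(mdl.state_dict(), ckpt_path)
```

Saving only the state_dict avoids the deep copy of the whole model mentioned in the question; the running stats are buffers and are written to the file together with the parameters.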

No, calling mdl.eval() before checkpointing does not delete the running averages, and unfortunately you didn't follow up in your other thread.
Saving the state_dict doesn’t have any known issues regarding previous calls to model.train() or model.eval(), so my guess is still that you might be manually deactivating the running stats updates.
If that’s not the case, please post a minimal executable code snippet which would reproduce the issue.
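For reference, a minimal check along these lines might look like the following. It is only a sketch under assumed conditions (a tiny made-up model with one `nn.BatchNorm1d` layer and random inputs), not a reproduction of the original code. It compares the batch norm running mean before and after an eval-mode forward pass, then round-trips the state_dict through an in-memory buffer.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
# Tiny model assumed for illustration only.
mdl = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
bn = mdl[1]

mdl.train()
for _ in range(5):
    mdl(torch.randn(16, 4))  # train-mode forward passes update the running stats
stats_after_train = bn.running_mean.clone()

mdl.eval()
with torch.no_grad():
    mdl(torch.randn(16, 4))  # eval-mode forward pass reads the running stats
stats_after_eval = bn.running_mean.clone()

# Checkpoint while in eval mode, then load into a fresh model.
buffer = io.BytesIO()
torch.save(mdl.state_dict(), buffer)
buffer.seek(0)
restored = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
restored.load_state_dict(torch.load(buffer))

print(torch.equal(stats_after_train, stats_after_eval))
print(torch.equal(restored[1].running_mean, stats_after_train))
```

If either comparison printed False in a real setup, that would point to something else modifying the buffers, for example `track_running_stats=False` or code that resets the batch norm layers.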