I usually alternate between training and testing in my code. As a result, I call
mdl.train() when returning to the training loop and
mdl.eval() when entering the testing loop. However, I've noticed that this may cause problems with how state that depends on the train/eval mode is saved (e.g. batch norm running statistics, perhaps others). This raises the question of how these flags should be set, if at all. If I call evaluation before checkpointing, I believe this will delete my running averages, as mentioned here: How does one use the mean and std from training in Batch Norm?.
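To make the concern concrete, here is a minimal sketch of how one might check whether an evaluation pass touches the batch norm running statistics. The toy model, the random inputs, and the variable names (`bn`, `stats_after_train`, `stats_after_eval`) are assumptions for illustration only, not my actual setup.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model; `mdl` stands in for the real network.
mdl = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
bn = mdl[1]

# Training-mode forward pass: BatchNorm updates its running statistics here.
mdl.train()
mdl(torch.randn(16, 4))
stats_after_train = (bn.running_mean.clone(), bn.running_var.clone())

# Evaluation pass: switch to eval mode and run a forward pass without gradients.
mdl.eval()
with torch.no_grad():
    mdl(torch.randn(16, 4))
stats_after_eval = (bn.running_mean.clone(), bn.running_var.clone())

# Compare the buffers recorded before and after the evaluation pass.
mean_unchanged = torch.equal(stats_after_train[0], stats_after_eval[0])
var_unchanged = torch.equal(stats_after_train[1], stats_after_eval[1])
print("running_mean unchanged by eval:", mean_unchanged)
print("running_var unchanged by eval:", var_unchanged)

# Switch back before resuming the training loop.
mdl.train()
```

Running a comparison like this on the real model, right before checkpointing, would show directly whether the buffers survive the switch to eval mode.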
I think making a deep copy of the model for evaluation should fix this, but that seems likely to exhaust my GPU memory (and perhaps my host memory too).
Thus, what is the proper way to alternate between evaluation and training in PyTorch so that checkpoints are saved properly (i.e. the running stats from training are NOT deleted)?
This is a challenge because I nearly always run evaluation to see whether the current model does better on validation than the previous best, and then decide based on THAT whether to save it. This means I always have to run an evaluation before checkpointing.
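The workflow I have in mind looks roughly like the sketch below. The toy model, random data, SGD optimizer, and the in-memory `io.BytesIO` buffer (standing in for a checkpoint file on disk) are all assumptions made to keep the example self-contained.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model and random data, for illustration only.
mdl = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
opt = torch.optim.SGD(mdl.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x_train, y_train = torch.randn(64, 4), torch.randn(64, 1)
x_val, y_val = torch.randn(32, 4), torch.randn(32, 1)

best_val = float("inf")
best_ckpt = None  # in-memory buffer standing in for a checkpoint file

for epoch in range(3):
    # Training step in train mode.
    mdl.train()
    opt.zero_grad()
    loss = loss_fn(mdl(x_train), y_train)
    loss.backward()
    opt.step()

    # Validation in eval mode, without gradient tracking.
    mdl.eval()
    with torch.no_grad():
        val_loss = loss_fn(mdl(x_val), y_val).item()

    # Checkpoint only if validation improved; this happens after the eval pass.
    if val_loss < best_val:
        best_val = val_loss
        best_ckpt = io.BytesIO()
        torch.save(mdl.state_dict(), best_ckpt)

# Reload the best checkpoint and list the batch norm buffers it contains.
best_ckpt.seek(0)
restored = torch.load(best_ckpt)
bn_keys = [k for k in restored if "running" in k]
print(bn_keys)
```

The point of the sketch is the ordering: the checkpoint is written after `mdl.eval()` has been called, which is exactly the situation I am worried about.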
Perhaps I could always run validation in train mode so that the running averages are saved correctly. Since the exact validation value doesn't matter, it may be fine if the statistics from train and val leak into the validation.