nn.BatchNorm2d cannot correctly restart training

Hi, I’m trying to resume training of a ResNet that has several batch norm layers with track_running_stats=True. Model accuracy drops significantly once I resume training (it does not drop if I set track_running_stats=False). I checked the weight, bias, running_mean, running_var, and num_batches_tracked of the batch norm layers, and they are all loaded correctly. However, I noticed that during the first training run, num_batches_tracked increments to 82 (the number of batches in my first epoch). In the second run, num_batches_tracked=82 is loaded, but after the first epoch of continued training it becomes 164.
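For illustration, here is a minimal sketch of this kind of check on a single standalone layer. The layer size, the random input batches, and the in-memory buffer used in place of a checkpoint file are all assumptions; this is not the actual ResNet setup. It saves a BatchNorm2d state_dict after 82 training batches, loads it into a fresh layer, and compares the buffers and the counter.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumed toy layer; track_running_stats=True is the default.
bn = nn.BatchNorm2d(3)
bn.train()
for _ in range(82):  # 82 batches, matching the epoch size in the post
    bn(torch.randn(4, 3, 8, 8))

# Save to an in-memory buffer (stand-in for a checkpoint file).
buffer = io.BytesIO()
torch.save(bn.state_dict(), buffer)
buffer.seek(0)

# Load into a freshly constructed layer.
restored = nn.BatchNorm2d(3)
restored.load_state_dict(torch.load(buffer))

# Compare weight, bias, running_mean, running_var, num_batches_tracked.
stats_match = all(
    torch.equal(value, restored.state_dict()[name])
    for name, value in bn.state_dict().items()
)
counter_after_load = int(restored.num_batches_tracked)

# Continue "training" for another 82 batches.
restored.train()
for _ in range(82):
    restored(torch.randn(4, 3, 8, 8))
counter_after_resume = int(restored.num_batches_tracked)

print(stats_match, counter_after_load, counter_after_resume)
```

In this sketch the counter continues from 82 to 164 after resuming, since num_batches_tracked is a cumulative count of training-mode forward passes. As far as I can tell from the BatchNorm implementation, the counter only affects the running averages when momentum=None; with the default momentum=0.1 it is tracked but not used in the update.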

My guess is that whatever was accumulated internally over those 82 batches was lost during saving/loading, so that when training resumes another 82 batches are accumulated, and somewhere in this process my network weights also get corrupted.

Is there a way for me to retain the previous progress while still keeping track_running_stats=True?

edit: I should clarify that eval seems to work fine.
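
Since eval looks fine and only continued training degrades, one thing that may be worth ruling out (an assumption, not a diagnosis) is whether the optimizer state is restored along with the model. Below is a minimal sketch of a checkpoint that stores both. The tiny make_model network, the SGD settings, and the in-memory buffer are hypothetical stand-ins, not the actual training code.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)


def make_model():
    # Hypothetical stand-in for the ResNet: one conv + batch norm block.
    return nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 2),
    )


model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# A few training steps on random data (assumed, for illustration only).
model.train()
for _ in range(5):
    x, y = torch.randn(4, 3, 8, 8), torch.randint(0, 2, (4,))
    optimizer.zero_grad()
    nn.functional.cross_entropy(model(x), y).backward()
    optimizer.step()

# Save model AND optimizer state together.
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Restore both into freshly constructed objects.
checkpoint = torch.load(buffer)
resumed = make_model()
resumed.load_state_dict(checkpoint["model"])
resumed_optimizer = torch.optim.SGD(resumed.parameters(), lr=0.1, momentum=0.9)
resumed_optimizer.load_state_dict(checkpoint["optimizer"])

# Eval-mode outputs should agree if parameters and buffers were restored.
model.eval()
resumed.eval()
probe = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    eval_outputs_match = torch.allclose(model(probe), resumed(probe))

# The restored optimizer should carry its per-parameter momentum buffers.
has_momentum_buffers = all(
    "momentum_buffer" in state for state in resumed_optimizer.state.values()
)
print(eval_outputs_match, has_momentum_buffers)
```

If only the model state_dict is saved, the optimizer restarts with empty state (and possibly a different learning rate), which could change the first resumed steps even though eval results match. Whether that is what happens here is not something this sketch can confirm.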