BatchNorm with track_running_stats=True versus False leads to major performance difference

I have a model that I trained for colorization (ImageNet), containing BatchNorm layers with the parameter track_running_stats=False. The results were very good, but I can’t export the model to ONNX, because BatchNorm with track_running_stats=False is currently unsupported (I get an error, and from what I can tell there is no workaround at the moment).
So I loaded the model weights into a model whose BatchNorm layers use track_running_stats=True (with strict=False) and resumed training, but the performance of the model decreased significantly, even after many epochs.
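For reference, here is a minimal sketch of that weight transfer; the `Net` module is a hypothetical stand-in for my actual colorization model:

```python
import torch
import torch.nn as nn

# Toy stand-in for the real colorization model (hypothetical)
class Net(nn.Module):
    def __init__(self, track_running_stats):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8, track_running_stats=track_running_stats)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))

src = Net(track_running_stats=False)  # trained model
dst = Net(track_running_stats=True)   # export-friendly model

# strict=False is needed because the source state_dict has no
# running_mean/running_var/num_batches_tracked entries for the BN layer
result = dst.load_state_dict(src.state_dict(), strict=False)
print(result.missing_keys)
# ['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']
```

The conv/BN weights and biases transfer over; only the running-stat buffers start from their defaults (zero mean, unit variance).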

Q1. Can someone please explain why this change impacts the performance so much? I know it has to do with the mean and variance at the batch level. When track_running_stats=False, the layer does not track these running (global) statistics, correct?

Q2. Also, is there a way to convert a PyTorch model to ONNX even when using BatchNorm with track_running_stats=False?

I would assume the training results would not see any impact, since the batch statistics would be used to normalize the activations in both cases (track_running_stats=True/False).
The validation results might be different, since the running stats might not represent the batch stats well (e.g. if the data comes from different distributions).
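This is easy to verify with a small standalone sketch (not from my actual model): two freshly initialized BatchNorm layers, one tracking running stats and one not, agree in train() mode but diverge in eval() mode when the input statistics differ from the running estimates:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn_track = nn.BatchNorm2d(3, track_running_stats=True)
bn_no_track = nn.BatchNorm2d(3, track_running_stats=False)

# Input whose statistics differ from the default running stats (mean 0, var 1)
x = torch.randn(8, 3, 4, 4) * 5 + 10

# Training mode: both layers normalize with the current batch statistics
bn_track.train(); bn_no_track.train()
train_equal = torch.allclose(bn_track(x), bn_no_track(x), atol=1e-5)

# Eval mode: the tracking layer switches to its running estimates,
# while the non-tracking layer keeps using the batch statistics
bn_track.eval(); bn_no_track.eval()
eval_equal = torch.allclose(bn_track(x), bn_no_track(x), atol=1e-3)

print(train_equal, eval_equal)  # True False
```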

Thank you for the reply, @ptrblck !
Indeed, this happens on test images, which can be quite different from the training dataset.

I will try to include some augmentation on the train data, so that it resembles the test data better.