I trained a colorization model on ImageNet whose BatchNorm layers use track_running_stats=False. The results were very good, but I can't convert the model to ONNX: the export fails with an error, and from what I can tell this configuration is not supported at the moment.
So I loaded the model weights into a model with BatchNorm (track_running_stats=True) using strict=False and resumed training, but the model's performance dropped significantly and did not recover even after many epochs.
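For reference, here is a minimal sketch of that loading step. It uses a toy network in place of my actual model, so the `make_net` helper and the layer sizes are hypothetical. It only shows which entries strict=False silently skips, and what the new BatchNorm buffers contain right after loading.

```python
import torch
import torch.nn as nn


def make_net(track: bool) -> nn.Sequential:
    # Hypothetical stand-in for the real colorization network.
    return nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8, track_running_stats=track),
        nn.ReLU(),
    )


src = make_net(track=False)  # how the model was originally trained
dst = make_net(track=True)   # the variant I want to export

# strict=False tolerates keys that exist in dst but not in the checkpoint.
result = dst.load_state_dict(src.state_dict(), strict=False)

# The running statistics are absent from the checkpoint, so they are reported
# as missing and keep their default initial values.
print("missing keys:", result.missing_keys)
print("running_mean:", dst[1].running_mean)
print("running_var: ", dst[1].running_var)
```

So the learned weight and bias are copied over, but the running mean and variance start from their defaults rather than from anything the original model saw.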
Q1. Can someone please explain why this change hurts performance so much? I know it has to do with the mean and variance being computed at the batch level. When track_running_stats=False, the layer does not track the running (global) mean and variance at all, correct?
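To make Q1 concrete, here is a small sketch of the difference as I understand it. The input shape and the scale/shift values are arbitrary assumptions, and both layers are freshly initialized rather than taken from my model. In eval mode the untracked layer still normalizes with the statistics of the current batch, while the tracked layer uses its stored running statistics.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary batch with per-channel mean around 5 and std around 3.
x = torch.randn(16, 4, 8, 8) * 3.0 + 5.0

bn_untracked = nn.BatchNorm2d(4, track_running_stats=False)
bn_tracked = nn.BatchNorm2d(4, track_running_stats=True)

# Put both layers in inference mode.
bn_untracked.eval()
bn_tracked.eval()

with torch.no_grad():
    y_untracked = bn_untracked(x)  # normalized with this batch's mean/var
    y_tracked = bn_tracked(x)      # normalized with running_mean/running_var

# No running buffers exist when track_running_stats=False.
print("untracked running_mean:", bn_untracked.running_mean)

# The untracked output is centered; the tracked output is not, because the
# fresh running statistics are still mean 0 / var 1.
print("untracked output mean:", y_untracked.mean().item())
print("tracked output mean:  ", y_tracked.mean().item())
```

If this is right, a model trained with track_running_stats=False has only ever seen activations normalized per batch, and switching to running statistics changes what every downstream layer receives at inference time.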
Q2. Also, is there a way to convert a PyTorch model to ONNX even when using BatchNorm with track_running_stats=False?