Understanding Instance Normalization 2D with running mean and running var

Hi, recently I have been trying to convert StarGAN v1 from PyTorch to ONNX; the model contains an instance normalization layer with track_running_stats=True. When I exported the model to ONNX, it turned out that the exporter does not export the running mean/variance. Nevertheless, the ONNX model still gives results comparable to the original model. I started wondering why that can happen, so I did a little experiment. I wanted to understand the difference between batch norm, instance norm with running mean/var, and instance norm without them, so I initialized three layers with the same weights.

import torch
import torch.nn as nn
import numpy as np

# three layers that share the same affine parameters (and running stats where applicable)
normTrue = nn.InstanceNorm2d(64, affine=True, track_running_stats=True)
normFalse = nn.InstanceNorm2d(64, affine=True)
bnorm = nn.BatchNorm2d(64, affine=True, track_running_stats=True)
input = torch.randn(10, 64, 128, 128)

w = torch.rand(64)
b = torch.rand(64)
m = torch.rand(64)
v = torch.rand(64)

with torch.no_grad():
    normTrue.weight = nn.Parameter(w)
    normTrue.bias = nn.Parameter(b)
    normTrue.running_mean = m   # note: normTrue and bnorm share the same running-stat tensors
    normTrue.running_var = v
    bnorm.weight = nn.Parameter(w)
    bnorm.bias = nn.Parameter(b)
    bnorm.running_mean = m
    bnorm.running_var = v
    normFalse.weight = nn.Parameter(w)
    normFalse.bias = nn.Parameter(b)

It turned out that in training mode, instance normalization with and without tracked running stats produces identical outputs.

with torch.no_grad():
    normoutTrue = normTrue(input).detach().cpu().numpy()
    normoutFalse = normFalse(input).detach().cpu().numpy()
    bnormout = bnorm(input).detach().cpu().numpy()
print(np.max(np.abs(normoutTrue - bnormout)))
print(np.max(np.abs(normoutTrue - normoutFalse)))
print(np.max(np.abs(bnormout - normoutFalse)))

0.05608654
0.0
0.05608654
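
For reference, the training-mode output can be reproduced by hand from per-sample, per-channel statistics alone, which is why the running stats make no difference here. A quick sketch, reusing the tensors above (eps=1e-5 is the layer default):

mean = input.mean(dim=(2, 3), keepdim=True)
var = input.var(dim=(2, 3), unbiased=False, keepdim=True)
manual = (input - mean) / torch.sqrt(var + 1e-5) * w.view(1, -1, 1, 1) + b.view(1, -1, 1, 1)
print(np.max(np.abs(manual.numpy() - normoutTrue)))  # should be ~1e-6, i.e. only float rounding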

But in inference (eval) mode they differ; in fact, instance norm with tracked running mean and variance behaves like batch norm.

normTrue.eval()
normFalse.eval()
bnorm.eval()

# export the eval-mode layers to ONNX
torch.onnx.export(normTrue, torch.rand(10, 64, 128, 128), "./norm.onnx")
torch.onnx.export(bnorm, torch.rand(10, 64, 128, 128), "./bnorm.onnx")

with torch.no_grad():
    normoutTrue = normTrue(input).detach().cpu().numpy()
    normoutFalse = normFalse(input).detach().cpu().numpy()
    bnormout = bnorm(input).detach().cpu().numpy()
print(np.max(np.abs(normoutTrue - bnormout)))
print(np.max(np.abs(normoutTrue - normoutFalse)))
print(np.max(np.abs(bnormout - normoutFalse)))

9.536743e-07
6.2534113
6.2534113
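
These eval-mode numbers line up with normalizing by the stored running statistics instead, which is exactly what batch norm does in eval mode. A minimal sketch of that computation (again with the default eps=1e-5):

rm = normTrue.running_mean.view(1, -1, 1, 1)
rv = normTrue.running_var.view(1, -1, 1, 1)
manual = (input - rm) / torch.sqrt(rv + 1e-5) * w.view(1, -1, 1, 1) + b.view(1, -1, 1, 1)
print(np.max(np.abs(manual.numpy() - normoutTrue)))  # should again be only float rounding
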
So I have a hard time understanding why the running mean/var should be used for normalization at inference time when they were not used during training, and I was also curious whether the ONNX exporter ignores the running mean/variance on purpose.
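
For reference, the exported file can also be inspected to see which tensors were actually serialized. A quick sketch using the onnx package (assuming it is installed):

import onnx

norm_model = onnx.load("./norm.onnx")
print([init.name for init in norm_model.graph.initializer])  # tensors serialized with the graph
print([node.op_type for node in norm_model.graph.node])      # exported ops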

Would be grateful for any clarification,
Best Regards


Instance norm is implemented using batch norm on a reshaped input (with the batch dimension folded into the channel dimension and an artificial size-1 batch dimension prepended), and then the running stats (which would be of shape batch * channel) are averaged over the “batch” part.
Once you accept this and use running stats, you can set the batch norm to train mode or not…
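
Roughly, that reshape trick looks like this with the functional API (a sketch for illustration, not the actual internal code):

import torch
import torch.nn.functional as F

x = torch.randn(10, 64, 128, 128)
N, C, H, W = x.shape

# per-(sample, channel) running stats for the reshaped "batch norm" view
running_mean = torch.zeros(N * C)
running_var = torch.ones(N * C)

# fold the batch dim into the channel dim, prepend a size-1 batch dim, run batch norm
out = F.batch_norm(x.reshape(1, N * C, H, W), running_mean, running_var,
                   training=True, eps=1e-5).reshape(N, C, H, W)
print(torch.max(torch.abs(out - F.instance_norm(x, eps=1e-5))))  # ~1e-6

# the per-(sample, channel) stats are then averaged over the "batch" part,
# giving the per-channel running stats the layer stores
per_channel_mean = running_mean.reshape(N, C).mean(dim=0)
per_channel_var = running_var.reshape(N, C).mean(dim=0)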

If your instance-norm statistics are reasonably stable, the statistics used in the two possible modes, “actual input stats” (if running stats are off or the mode is training) and “collected running stats” (with running stats on in eval mode), should be similar. The latter might be faster at inference because no statistics need to be gathered.

That said, track_running_stats=True is rather not what people expect from instance norm, so I would recommend not using it unless you have a specific reason to do so.

Best regards

Thomas


@tom Thanks for reaching out. Do I understand correctly that training with track_running_stats=True or False is the same, and that the two only differ at inference time? If so, it would make sense that the outputs at inference are of worse quality than in training mode.

These are the results from running inference in training vs. evaluation mode:


import matplotlib.pyplot as plt

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

model = G.main  # main module of the StarGAN v1 generator

# forward pass in training mode
with torch.no_grad():
    outTrain = to_numpy(model(realInput))[0]

# forward pass in eval mode
model.eval()
with torch.no_grad():
    outEval = to_numpy(model(realInput))[0]

# postprocess from [-1, 1] to [0, 1] and plot
outTrain = (outTrain + 1) / 2
outEval = (outEval + 1) / 2
plt.imshow(outTrain.transpose((1, 2, 0)))
plt.imshow(outEval.transpose((1, 2, 0)))

Then I guess it makes sense that the ONNX exporter does not export the running mean/variance.


Yes, this is exactly my understanding as well.

Then I guess it makes sense that the ONNX exporter does not export the running mean/variance.

Indeed.

Best regards

Thomas


Hello! Did you solve the problem of the PyTorch and ONNX results being different? I use InstanceNorm3d with track_running_stats=False, but the ONNX and PyTorch results differ from the second decimal place onward. I am looking forward to your reply! Thank you!