I want to freeze (fix) the running_mean and running_var in BN during some training iterations, so I set track_running_stats=False during training. But why does running_mean still update?
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
for _ in range(10):
    x = torch.randn(10, 3, 24, 24)
    out = bn(x)

print(bn.running_mean)
> tensor([ 0.0031, -0.0034, -0.0066])
print(bn.running_var)
> tensor([1.0027, 0.9977, 1.0081])
print(bn)
> BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

bn.track_running_stats = False
print(bn)
> BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

for _ in range(10):
    x = torch.randn(10, 3, 24, 24)
    out = bn(x)

print(bn.running_mean)
> tensor([ 0.0014, -0.0021, -0.0048])
print(bn.running_var)
> tensor([1.0036, 0.9980, 0.9983])
track_running_stats is used to initialize the running stats in the __init__ method, as seen here.
If you don't want to update these stats, you should call bn.eval() and switch back afterwards via bn.train() to continue updating them.
Can you give example code to illustrate the sequence of the process you mentioned?
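A minimal sketch of that sequence could look like the following (illustrative only, using random data; the last two lines additionally show what happens if track_running_stats=False is passed at construction):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)  # track_running_stats=True by default

# 1) Train normally: batch stats are used and the running stats are updated.
bn.train()
for _ in range(10):
    out = bn(torch.randn(10, 3, 24, 24))
frozen_mean = bn.running_mean.clone()

# 2) Freeze: in eval mode the running stats are used and no longer updated.
bn.eval()
for _ in range(10):
    out = bn(torch.randn(10, 3, 24, 24))
print(torch.equal(bn.running_mean, frozen_mean))  # True: stats unchanged

# 3) Resume: back in train mode the running stats are updated again.
bn.train()
for _ in range(10):
    out = bn(torch.randn(10, 3, 24, 24))
print(torch.equal(bn.running_mean, frozen_mean))  # False: stats updated again

# With track_running_stats=False at construction, no running stats are created:
bn_no_stats = nn.BatchNorm2d(3, track_running_stats=False)
print(bn_no_stats.running_mean, bn_no_stats.running_var)  # None None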
I am having a problem using batch normalization: the performance of my model drops significantly when running in inference mode (net.eval()).
I am looking for how to properly train and run inference with my model using batch normalization with track_running_stats=False. Could you give some guidance?
If you set track_running_stats=False, the batchnorm layer will always use the batch statistics to normalize the activations, so there is no need to call train() or eval() on this layer anymore.
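A quick sanity check along these lines (a minimal sketch with random data) shows that train() and eval() produce the same output when the layer is created with track_running_stats=False:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, track_running_stats=False)
x = torch.randn(8, 3, 24, 24)

bn.train()
out_train = bn(x)

bn.eval()
out_eval = bn(x)

# Both modes normalize with the batch statistics, since no running stats exist.
print(torch.allclose(out_train, out_eval))  # True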
@ptrblck, I used track_running_stats=False, but changing the batch size at test time changes the results dramatically (e.g., batch_size=32 gives 99% accuracy while batch_size=1 gives 81% accuracy). How can I handle that?
This is expected, since with track_running_stats=False the normalization is performed using the batch statistics and is thus dependent on the batch size, which is generally not a desired property for model inference. The best you can likely do is try to keep the batch size close to the “optimal” setup.
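For example (a minimal sketch with random data), the same sample is normalized differently depending on the batch it is evaluated with, because the batch statistics themselves change:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3, track_running_stats=False)
bn.eval()  # has no effect here, batch stats are always used

sample = torch.randn(1, 3, 24, 24)
batch = torch.cat([sample, torch.randn(31, 3, 24, 24)], dim=0)

out_alone = bn(sample)        # normalized with the stats of this single sample
out_in_batch = bn(batch)[:1]  # same sample, normalized with the stats of 32 samples

print(torch.allclose(out_alone, out_in_batch))  # False in general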
Okay. Actually, my network architecture is a Siamese network (metric learning), inspired by a paper in which track_running_stats=False is used. Due to the batch-size problem during inference, I tried track_running_stats=True, but then the same model started heavily overfitting, even after I increased the dropout and weight decay.
Does that mean the running mean and variance are unsuitable for my test data?
Any advice would be a great help.
That’s hard to tell, but depending on the input data and its statistics, the running stats of the batchnorm layers might indeed not converge to the batch statistics.
Thank you for your response. So would you suggest trying LayerNorm or InstanceNorm in such a case? I know I could just try it and see the results, but I wanted to know your intuition.
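For intuition, norm layers that compute their statistics per sample (such as InstanceNorm or LayerNorm) do not depend on the batch size at all. A minimal sketch with nn.InstanceNorm2d and random data:

import torch
import torch.nn as nn

inorm = nn.InstanceNorm2d(3, affine=True)

sample = torch.randn(1, 3, 24, 24)
batch = torch.cat([sample, torch.randn(31, 3, 24, 24)], dim=0)

# Statistics are computed per sample and channel, so the output for a given
# sample does not change with the rest of the batch.
print(torch.allclose(inorm(sample), inorm(batch)[:1]))  # True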