What is 'num_batches_tracked' in the new BN for?

In PyTorch 0.4.1, the newly added buffer num_batches_tracked in BN can cause pretrained-model incompatibility with older versions, e.g. "[SOLVED] Unexpected key(s) in state_dict: batches_tracked".

So I am wondering: why do we need this buffer?


It’s used to update running_mean and running_var.
They are the statistics used to normalize input samples at test/validation time.
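A minimal sketch of this point: during training the running statistics are updated, and in eval mode they (not the batch statistics) are what BatchNorm normalizes with. The check below assumes a freshly initialized BatchNorm2d, whose affine weight is 1 and bias is 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)

x = torch.randn(8, 3, 4, 4)
bn.train()
_ = bn(x)  # updates running_mean, running_var, and num_batches_tracked

bn.eval()
y = bn(x)

# In eval mode the output is normalized with the running statistics,
# not with the statistics of the current batch:
expected = (x - bn.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
    bn.running_var.view(1, -1, 1, 1) + bn.eps
)
print(torch.allclose(y, expected, atol=1e-5))  # True (weight=1, bias=0 at init)
```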

Hi @crcrpar, could you please specify how this affects running_mean and running_var? Thanks!


Any new comments on this?

Hi, I think you can check this page to see what it is for:
github.com/pytorch/batchnorm.py

Just as its name implies, assuming you are using torch.nn.BatchNorm2d (with the default track_running_stats=True):

  1. During training, num_batches_tracked (starting from zero) is incremented by 1 for each mini-batch.
  2. If you didn’t specify momentum (the factor used for the exponential moving average of the running mean and variance), then 1/num_batches_tracked is used as the averaging factor instead (i.e. a cumulative moving average).

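The two steps above can be sketched as follows. This assumes momentum=None, in which case the effective update factor is 1/num_batches_tracked, so running_mean ends up being the plain average of the per-batch means:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(1, momentum=None)  # momentum=None -> cumulative averaging
bn.train()

batches = [torch.randn(32, 1) + i for i in range(5)]
for b in batches:
    bn(b)  # each forward pass in train mode increments num_batches_tracked

print(int(bn.num_batches_tracked))  # 5: one increment per mini-batch

# With a 1/num_batches_tracked factor, running_mean is the simple
# average of the per-batch means seen so far:
batch_means = torch.stack([b.mean() for b in batches])
print(torch.allclose(bn.running_mean, batch_means.mean().unsqueeze(0), atol=1e-5))  # True
```

With the default momentum=0.1 the buffer is still incremented, but the updates use the fixed exponential factor instead.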
I guess it should be clear why we need this buffer.
