In PyTorch 0.4.1, the newly added buffer ‘num_batches_tracked’ in BatchNorm can make pretrained models incompatible with older versions, e.g. [SOLVED] Unexpected key(s) in state_dict: "num_batches_tracked".
So I am wondering: why do we need this buffer?
It’s used when updating running_mean and running_var.
Those running statistics are what the layer uses to normalize inputs at test/validation time.
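For instance, here is a minimal sketch showing that the counter only advances during training forward passes (assuming default settings, i.e. track_running_stats=True):

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)           # track_running_stats=True by default
print(bn.num_batches_tracked)    # tensor(0)

bn.train()
for _ in range(5):
    bn(torch.randn(8, 3, 4, 4))  # each training forward pass increments the counter
print(bn.num_batches_tracked)    # tensor(5)

bn.eval()
bn(torch.randn(8, 3, 4, 4))      # in eval mode, running stats and the counter are frozen
print(bn.num_batches_tracked)    # still tensor(5)
```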
Hi @crcrpar, could you please specify how this affects running_mean and running_var? Thanks!
Any new comments on this?
Hi, I think you can check this page to see what it is for:
github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
Just as its name implies, assuming you use torch.nn.BatchNorm2d (with track_running_stats=True, the default):

- num_batches_tracked starts from zero and is incremented by 1 for every mini-batch processed in training mode.
- If momentum is set (the default is 0.1), the running mean and variance are updated by an exponential moving average with that momentum, and the counter plays no role in the update.
- If momentum is None, the module instead uses 1/num_batches_tracked as the update factor, i.e. a cumulative moving average over all batches seen so far (see the sketch below).

I guess it should be clear why we need this buffer: without it, the cumulative method would have nothing to divide by.
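Here is a minimal sketch of that factor selection, paraphrased from the forward logic in torch/nn/modules/batchnorm.py; the helper name running_stats_factor is made up for illustration, not part of the PyTorch API:

```python
import torch
import torch.nn as nn

def running_stats_factor(bn):
    """Hypothetical helper: mimics how BatchNorm's forward pass picks the
    update factor for running_mean / running_var (paraphrased sketch)."""
    if bn.momentum is None:
        # Cumulative moving average: every batch seen so far is weighted
        # equally, which is only possible because the count is stored.
        return 1.0 / float(bn.num_batches_tracked)
    # Fixed momentum: ordinary exponential moving average.
    return bn.momentum

bn = nn.BatchNorm2d(3, momentum=None)  # request the cumulative average
bn(torch.randn(8, 3, 4, 4))            # num_batches_tracked is now 1
print(running_stats_factor(bn))        # 1.0 -> the first batch fully sets the stats
```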