In PyTorch 0.4.1, the newly added buffer `num_batches_tracked` in BatchNorm can make pretrained models incompatible with older versions, e.g. "[SOLVED] Unexpected key(s) in state_dict: batches_tracked".
So I am wondering: why do we need this buffer?
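For the loading error itself, one common workaround is to drop the unexpected keys before calling `load_state_dict` (or to pass `strict=False`). A minimal sketch, assuming the checkpoint is a plain `state_dict` saved from a model containing BatchNorm layers:

```python
import torch
import torch.nn as nn

# A model saved with PyTorch >= 0.4.1: its state_dict contains
# a "num_batches_tracked" entry for every BatchNorm layer.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
state = model.state_dict()

# Drop the buffers an older PyTorch does not know about.
filtered = {k: v for k, v in state.items() if "num_batches_tracked" not in k}

# strict=False additionally tolerates any remaining missing/unexpected keys.
model.load_state_dict(filtered, strict=False)
```

Going the other way (loading an old checkpoint into a new PyTorch), `strict=False` alone is usually enough, since `num_batches_tracked` is then merely a missing key.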
It’s used to update `running_mean` and `running_var`.
They are used to normalize input samples from test/validation datasets.
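That train/eval difference can be seen directly; a small sketch (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)

# Training mode: each forward pass normalizes with the batch's own
# statistics and, as a side effect, updates running_mean / running_var.
bn.train()
x = torch.randn(2, 4, 5, 5)
_ = bn(x)

# Eval mode: normalization uses the accumulated running statistics,
# so test/validation samples are scaled consistently, independent of
# the composition of the batch they arrive in.
bn.eval()
y = bn(x)
```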
Hi @crcrpar, could you please specify how this affects `running_mean` and `running_var`? Thanks!
Any new comments on this?
Hi, I think you can check this page to see what it is for:
github.com/pytorch/batchnorm.py
Just as its name implies, assuming you use `torch.nn.BatchNorm2d` (by default, with `track_running_stats=True`):

- `num_batches_tracked` starts from zero and is incremented by 1 for each mini-batch seen in training mode.
- `momentum` is normally the factor used to update the running mean and variance as an exponential moving average. But when `momentum` is set to `None`, BatchNorm instead uses `1 / num_batches_tracked` as the update factor, which turns the update into a cumulative (plain) average over all batches seen so far.

I guess it should be clear why we need this buffer.
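The cumulative behavior with `momentum=None` can be checked with two tiny batches whose means are easy to average by hand (the numbers below are made up for illustration):

```python
import torch
import torch.nn as nn

# With momentum=None, BatchNorm keeps a cumulative moving average:
# after N batches, running_mean equals the plain average of the N
# batch means, because 1 / num_batches_tracked is the update factor.
bn = nn.BatchNorm1d(1, momentum=None)
bn.train()

b1 = torch.tensor([[1.0], [3.0]])  # batch mean = 2.0
b2 = torch.tensor([[5.0], [7.0]])  # batch mean = 6.0
bn(b1)
bn(b2)

print(bn.num_batches_tracked.item())  # 2
print(bn.running_mean.item())         # (2.0 + 6.0) / 2 = 4.0
```

With the default `momentum=0.1` instead, the second update would give `0.9 * 2.0 + 0.1 * 6.0 = 2.4`, i.e. an exponential rather than cumulative average.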