Batch Normalization Layer: saving and loading the running stats

When I save a model with torch.save(model.state_dict(), PATH) and later load it with model.load_state_dict(torch.load(PATH)), what happens to the running mean and variance of a batch normalization layer? Are they saved and restored with the same values, or are they reset to their defaults when the model is re-initialized and loaded from the saved state_dict?


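Yes, they are saved and restored. running_mean and running_var are registered as buffers on the layer, so they are part of its state_dict, as seen here: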

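import torch
import torch.nn as nn

# A freshly created layer starts with mean 0 and variance 1
bn = nn.BatchNorm2d(3)
print(bn.running_mean)
# tensor([0., 0., 0.])
print(bn.running_var)
# tensor([1., 1., 1.])

# A forward pass in training mode updates the running stats
# (the exact values depend on the random input)
out = bn(torch.randn(10, 3, 24, 24))
print(bn.running_mean)
# tensor([0.0009, 0.0018, 0.0004])
print(bn.running_var)
# tensor([1.0009, 1.0015, 1.0022])

torch.save(bn.state_dict(), 'tmp.pt')

# Re-create the layer: the stats are back at their defaults
bn = nn.BatchNorm2d(3)
print(bn.running_mean)
# tensor([0., 0., 0.])
print(bn.running_var)
# tensor([1., 1., 1.])

# Loading the state_dict restores the saved running stats
bn.load_state_dict(torch.load('tmp.pt'))
print(bn.running_mean)
# tensor([0.0009, 0.0018, 0.0004])
print(bn.running_var)
# tensor([1.0009, 1.0015, 1.0022])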
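
If you want to check this programmatically rather than by eye, here is a minimal sketch. It assumes the default BatchNorm2d settings (affine=True, track_running_stats=True); the helper name roundtrip_batchnorm_stats is made up for this example, and it uses an in-memory buffer instead of a file only to avoid leaving tmp.pt on disk.

```python
import io

import torch
import torch.nn as nn


def roundtrip_batchnorm_stats(num_features=3):
    """Hypothetical helper: update the running stats, then save and reload them."""
    bn = nn.BatchNorm2d(num_features)
    bn.train()
    # One forward pass in training mode updates running_mean / running_var
    bn(torch.randn(10, num_features, 24, 24))

    # Save the state_dict to an in-memory buffer (stands in for a file path)
    buffer = io.BytesIO()
    torch.save(bn.state_dict(), buffer)
    buffer.seek(0)

    # A fresh layer starts from the defaults, then receives the saved values
    restored = nn.BatchNorm2d(num_features)
    restored.load_state_dict(torch.load(buffer))
    return bn, restored


bn, restored = roundtrip_batchnorm_stats()

# The running stats are buffers, so they appear next to weight and bias
print(sorted(bn.state_dict().keys()))

# The reloaded layer carries exactly the same running stats
print(torch.equal(bn.running_mean, restored.running_mean))
print(torch.equal(bn.running_var, restored.running_var))

# In eval mode both layers normalize with those stats, so outputs match
x = torch.randn(4, 3, 8, 8)
print(torch.equal(bn.eval()(x), restored.eval()(x)))
```

Note that the state_dict also contains num_batches_tracked, which is saved and restored the same way.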

Thanks for your response!