Batch Normalization Layer: saving and loading the running stats

When saving a model with torch.save(model.state_dict(), PATH) and subsequently loading it with model.load_state_dict(torch.load(PATH)), what happens to the running mean and variance of a batch normalization layer? Are they saved and restored with the same values, or are they reset to their defaults when the model is re-initialized before loading the saved state_dict?


They are saved and loaded, as seen here:

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)
print(bn.running_mean)
> tensor([0., 0., 0.])
print(bn.running_var)
> tensor([1., 1., 1.])

# a forward pass in training mode updates the running stats
out = bn(torch.randn(10, 3, 24, 24))
print(bn.running_mean)
> tensor([0.0009, 0.0018, 0.0004])
print(bn.running_var)
> tensor([1.0009, 1.0015, 1.0022])

torch.save(bn.state_dict(), 'tmp.pt')

# a freshly initialized layer starts from the default stats again
bn = nn.BatchNorm2d(3)
print(bn.running_mean)
> tensor([0., 0., 0.])
print(bn.running_var)
> tensor([1., 1., 1.])

# loading the state_dict restores the updated running stats
bn.load_state_dict(torch.load('tmp.pt'))
print(bn.running_mean)
> tensor([0.0009, 0.0018, 0.0004])
print(bn.running_var)
> tensor([1.0009, 1.0015, 1.0022])
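
This works because running_mean and running_var are registered as buffers: they are part of the module's state_dict even though they are not trainable parameters. A minimal sketch to verify this:

import torch.nn as nn

bn = nn.BatchNorm2d(3)

# trainable parameters: only the affine weight and bias
print([name for name, _ in bn.named_parameters()])
> ['weight', 'bias']

# buffers: non-trainable state, including the running stats
print([name for name, _ in bn.named_buffers()])
> ['running_mean', 'running_var', 'num_batches_tracked']

# state_dict contains both, which is why save/load round-trips the stats
print(list(bn.state_dict().keys()))
> ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']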

Thanks for your response!

Hi, could you please help identify why the batch norm running stats are not loaded properly for RNet1 below, while they are loaded properly for RNet? “latest_clipRN50_fine_tune_bn.pkl” is a checkpoint of a fine-tuned model with the MyNet architecture.

import torch
import clip
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self, my_pretrained_model):
        super().__init__()
        self.pretrained = my_pretrained_model
        self.new_fc_head = nn.Linear(1024, 1000).to(torch.float16)

    def forward(self, x):
        ft = self.pretrained(x)
        x = self.new_fc_head(ft)
        return x, ft


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # RNet: wrap the CLIP visual encoder in MyNet, then load the checkpoint
    clip_model, preprocess = clip.load("RN50", device=device, jit=False)
    model = clip_model.visual
    RNet = MyNet(my_pretrained_model=model).to(device)
    ckpt = torch.load('latest_clipRN50_fine_tune_bn.pkl', map_location=torch.device(device))
    RNet.load_state_dict(ckpt, strict=False)
    RNet.eval()

    # RNet1: load the same checkpoint directly into the full CLIP model
    RNet1, preprocess1 = clip.load("RN50", device=device, jit=False)
    RNet1.load_state_dict(ckpt, strict=False)
    RNet1.eval()

    # compare the loaded tensors against the checkpoint
    print(torch.equal(RNet.pretrained.bn1.running_var, ckpt['pretrained.bn1.running_var']))
    print(torch.equal(RNet1.visual.bn1.running_var, ckpt['pretrained.bn1.running_var']))
    print(torch.equal(RNet1.visual.bn1.running_var, RNet.pretrained.bn1.running_var))

    print(torch.equal(RNet.pretrained.conv1.weight, ckpt['pretrained.conv1.weight']))
    print(torch.equal(RNet1.visual.conv1.weight, ckpt['pretrained.conv1.weight']))
    print(torch.equal(RNet1.visual.conv1.weight, RNet.pretrained.conv1.weight))

Output:

> True
> False
> False
> True
> True
> True

Why are you using strict=False? It silently ignores missing and unexpected keys, and that is most likely the cause of your issue: the checkpoint keys are prefixed with pretrained. (the attribute name in MyNet), while RNet1 is the full CLIP model and expects keys prefixed with visual., so none of the batch norm entries match and RNet1 simply keeps its original stats. The conv1 comparison presumably only passes because those weights were unchanged by your fine-tuning.
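
Even with strict=False, the call returns the mismatches so you can inspect them instead of discarding them. A minimal sketch continuing from your script (the remapping assumes MyNet stored the visual encoder under the pretrained. prefix, as in your code above):

# inspect what load_state_dict actually matched; with strict=False it
# returns the mismatches instead of raising
result = RNet1.load_state_dict(ckpt, strict=False)
print(result.missing_keys[:5])     # keys RNet1 expected but never received (visual.* ...)
print(result.unexpected_keys[:5])  # checkpoint keys that matched nothing (pretrained.* ...)

# remap the prefix so the checkpoint lines up with the full CLIP model
remapped = {k.replace('pretrained.', 'visual.', 1): v for k, v in ckpt.items()}
# strict=False is still needed here: the checkpoint also contains MyNet's
# new_fc_head.* keys, which do not exist in RNet1
RNet1.load_state_dict(remapped, strict=False)
print(torch.equal(RNet1.visual.bn1.running_var, ckpt['pretrained.bn1.running_var']))
# should now print True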