How to change the BatchNormalization values?

Hi, all.

I want to move the BN values of model_1 to model_2.

The test results before and after changing the BN values are identical, so I think something went wrong somewhere, but I can't find the problem.

def save_bn_stats(model):
    model_encoder_dict = model.encoder.state_dict()

    cur_encoder_dict = OrderedDict()
    for key, value in model_encoder_dict.items():
        if 'bn' in key:  # keep only the BatchNorm entries
            cur_encoder_dict[key] = value

    return encoder_dict

def change_bn_stats(model, new_bn_stat):
    model_dict = model.state_dict()

    # overwrite the matching entries with the saved stats
    for key, value in new_bn_stat.items():
        model_dict[key] = copy.deepcopy(value)

    model.load_state_dict(model_dict, strict=True)

def validate(model):
    model.eval()
    # validation code ...

validate(model_2)                    # baseline results
bn_stat_1 = save_bn_stats(model_1)
change_bn_stats(model_2, bn_stat_1)
validate(model_2)                    # results are unexpectedly unchanged

Your code seems to work for me:

def save_bn_stats(model):
    model_encoder_dict = model.state_dict()

    cur_encoder_dict = OrderedDict()
    for key, value in model_encoder_dict.items():
        cur_encoder_dict[key] = model_encoder_dict[key]

    return cur_encoder_dict

def change_bn_stats(model, new_bn_stat):
    model_dict = model.state_dict()

    for key, value in new_bn_stat.items():
        model_dict[key] = copy.deepcopy(value)  

    model.load_state_dict(model_dict, strict=True)


model_1 = nn.BatchNorm2d(3)
model_2 = nn.BatchNorm2d(3)

# check that both state dicts are equal
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))])
print(model_2.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))])

optimizer = torch.optim.SGD(model_1.parameters(), lr=1.)
# train model_1 for a few iterations
for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(1, 3, 224, 224)
    out = model_1(x)
    out.mean().backward()
    optimizer.step()

# check that model_1 was updated    
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-3.3333, -3.3333, -3.3333])), ('running_mean', tensor([-0.0010,  0.0017, -0.0002])), ('running_var', tensor([1.0015, 1.0008, 1.0005])), ('num_batches_tracked', tensor(10))])
print(model_2.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))])

bn_stat_1 = save_bn_stats(model_1)
change_bn_stats(model_2, bn_stat_1)

# check that state dicts are equal again
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-3.3333, -3.3333, -3.3333])), ('running_mean', tensor([-0.0010,  0.0017, -0.0002])), ('running_var', tensor([1.0015, 1.0008, 1.0005])), ('num_batches_tracked', tensor(10))])
print(model_2.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-3.3333, -3.3333, -3.3333])), ('running_mean', tensor([-0.0010,  0.0017, -0.0002])), ('running_var', tensor([1.0015, 1.0008, 1.0005])), ('num_batches_tracked', tensor(10))])

after fixing the wrong return encoder_dict statement in save_bn_stats, since encoder_dict is undefined there and should be cur_encoder_dict.

@ptrblck Thanks for your reply. I checked that your example works fine.
But I'm still wondering why I fail to store bn_stat_1 in this example:

import torch
from collections import OrderedDict
import copy
import torch.nn as nn


def save_bn_stats(model):
    model_encoder_dict = model.state_dict()

    cur_encoder_dict = OrderedDict()
    for key, value in model_encoder_dict.items():
        cur_encoder_dict[key] = model_encoder_dict[key]

    return cur_encoder_dict


def change_bn_stats(model, new_bn_stat):
    model_dict = model.state_dict()

    for key, value in new_bn_stat.items():
        model_dict[key] = copy.deepcopy(value)

    model.load_state_dict(model_dict, strict=True)


model_1 = nn.BatchNorm2d(3)
model_2 = nn.BatchNorm2d(3)

optimizer = torch.optim.SGD(model_1.parameters(), lr=1.)
# train model_1 for a few iterations

for step in range(5):
    optimizer.zero_grad()
    x = torch.randn(1, 3, 224, 224)
    out = model_1(x)
    out.mean().backward()
    optimizer.step()
    if step == 1:
        bn_stat_1 = save_bn_stats(model_1)
        print(bn_stat_1)
        # OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-0.6667, -0.6667, -0.6667])), ('running_mean', tensor([-7.7516e-04,  7.2945e-05, -1.9681e-04])), ('running_var', tensor([1.0006, 1.0015, 0.9999])), ('num_batches_tracked', tensor(2))])

    if step == 4:
        bn_stat_2 = save_bn_stats(model_1)
        print(bn_stat_2)
        # OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-1.6667, -1.6667, -1.6667])), ('running_mean', tensor([-1.3079e-03,  1.2919e-05,  5.0353e-04])), ('running_var', tensor([1.0012, 1.0015, 0.9987])), ('num_batches_tracked', tensor(5))])

change_bn_stats(model_1, bn_stat_1)  # fails to restore bn_stat_1
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-1.6667, -1.6667, -1.6667])), ('running_mean', tensor([-1.3079e-03,  1.2919e-05,  5.0353e-04])), ('running_var', tensor([1.0012, 1.0015, 0.9987])), ('num_batches_tracked', tensor(2))])
change_bn_stats(model_1, bn_stat_2)
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-1.6667, -1.6667, -1.6667])), ('running_mean', tensor([-1.3079e-03,  1.2919e-05,  5.0353e-04])), ('running_var', tensor([1.0012, 1.0015, 0.9987])), ('num_batches_tracked', tensor(2))])

You are storing references in save_bn_stats, so bn_stat_1 changes along with the model while training continues; you would need to clone the data or use copy.deepcopy instead.
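For example, a minimal sketch of save_bn_stats with the clone applied (same function as above, only the stored values change):

def save_bn_stats(model):
    model_encoder_dict = model.state_dict()

    cur_encoder_dict = OrderedDict()
    for key, value in model_encoder_dict.items():
        # detach().clone() makes an independent copy; storing value directly
        # keeps a reference to the live buffer, which training updates in place
        cur_encoder_dict[key] = value.detach().clone()

    return cur_encoder_dict

With this version bn_stat_1 keeps the step-2 statistics even after training continues, so change_bn_stats restores them as expected; returning copy.deepcopy(model.state_dict()) would have the same effect.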
