Your code seems to work for me:
def save_bn_stats(model):
    """Snapshot a module's state dict (weights, biases, and BatchNorm
    running statistics) into a new ``OrderedDict``.

    Each value is deep-copied because ``model.state_dict()`` returns
    references to the live parameter/buffer tensors — without the copy,
    the "saved" stats would silently keep mutating as the model trains
    (``running_mean``/``running_var``/``num_batches_tracked`` are updated
    in place on every training-mode forward pass).

    Args:
        model: any ``nn.Module`` whose current state should be captured.

    Returns:
        ``OrderedDict`` mapping state-dict keys to independent copies of
        the current tensor values.
    """
    snapshot = OrderedDict()
    for key, value in model.state_dict().items():
        snapshot[key] = copy.deepcopy(value)
    return snapshot
def change_bn_stats(model, new_bn_stat):
    """Overwrite entries of *model*'s state dict with values from
    *new_bn_stat* and load the merged result back into the model.

    Keys absent from *new_bn_stat* keep the model's current values;
    ``strict=True`` still requires the merged dict to cover every key
    the model expects.

    Args:
        model: target ``nn.Module``, updated in place.
        new_bn_stat: mapping of state-dict keys to replacement tensors
            (e.g. the dict produced by ``save_bn_stats``).
    """
    merged = model.state_dict()
    merged.update({name: copy.deepcopy(tensor) for name, tensor in new_bn_stat.items()})
    model.load_state_dict(merged, strict=True)
# Demo: train one BatchNorm layer, then transfer its affine parameters and
# running statistics into an untouched twin via save_bn_stats/change_bn_stats.
model_1 = nn.BatchNorm2d(3)
model_2 = nn.BatchNorm2d(3)
# check that both state dicts are equal (fresh BN init: weight=1, bias=0,
# running_mean=0, running_var=1, num_batches_tracked=0)
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))])
print(model_2.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))])
# lr=1. is deliberately large so the bias visibly diverges in few steps
optimizer = torch.optim.SGD(model_1.parameters(), lr=1.)
# train model_1 for a few iterations (training mode also updates the
# running stats and num_batches_tracked in place on each forward pass)
for _ in range(10):
    optimizer.zero_grad()
    x = torch.randn(1, 3, 224, 224)
    out = model_1(x)
    out.mean().backward()
    optimizer.step()
# check that model_1 was updated while model_2 is untouched
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-3.3333, -3.3333, -3.3333])), ('running_mean', tensor([-0.0010, 0.0017, -0.0002])), ('running_var', tensor([1.0015, 1.0008, 1.0005])), ('num_batches_tracked', tensor(10))])
print(model_2.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([0., 0., 0.])), ('running_mean', tensor([0., 0., 0.])), ('running_var', tensor([1., 1., 1.])), ('num_batches_tracked', tensor(0))])
# snapshot model_1's state and load it into model_2
bn_stat_1 = save_bn_stats(model_1)
change_bn_stats(model_2, bn_stat_1)
# check that the two state dicts are equal again
print(model_1.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-3.3333, -3.3333, -3.3333])), ('running_mean', tensor([-0.0010, 0.0017, -0.0002])), ('running_var', tensor([1.0015, 1.0008, 1.0005])), ('num_batches_tracked', tensor(10))])
print(model_2.state_dict())
# OrderedDict([('weight', tensor([1., 1., 1.])), ('bias', tensor([-3.3333, -3.3333, -3.3333])), ('running_mean', tensor([-0.0010, 0.0017, -0.0002])), ('running_var', tensor([1.0015, 1.0008, 1.0005])), ('num_batches_tracked', tensor(10))])
This works after fixing the incorrect `return encoder_dict` statement in `save_bn_stats`, since `encoder_dict` is undefined there (the dict being built is named `cur_encoder_dict`).