Smoothed Update of Network Parameters

Hello,

I have two networks, target_value and value.

I am trying to update the parameters of the target_value network so that each update is a smoothed (weighted) average of the target_value and value parameters, with weight tau on value. The following code works well as long as my networks do not contain a batch normalisation layer.

    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau

        target_value_params = self.target_value.named_parameters()
        value_params = self.value.named_parameters()

        target_value_state_dict = dict(target_value_params)
        value_state_dict = dict(value_params)

        for name in value_state_dict:
            value_state_dict[name] = tau*value_state_dict[name].clone() + \
                                     (1-tau)*target_value_state_dict[name].clone()

        self.target_value.load_state_dict(value_state_dict)

However, once I include batch normalisation layers, I get the following error:

RuntimeError: Error(s) in loading state_dict for ValueNetwork:
	Missing key(s) in state_dict: "feature_extractor.extractors.price_features.co1.bn1.running_mean", "feature_extractor.extractors.price_features.co1.bn1.running_var", "feature_extractor.extractors.price_features.co2.bn1.running_mean", "feature_extractor.extractors.price_features.co2.bn1.running_var"

I understand the error occurs because .named_parameters() does not include the batch norm buffers (running_mean and running_var), so they are missing from the dict I pass to load_state_dict.
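To see exactly which keys are missing, a quick check on a small stand-in model (a hypothetical Linear followed by BatchNorm1d, not the actual ValueNetwork) compares named_parameters(), named_buffers() and state_dict(). In current PyTorch versions BatchNorm also registers a num_batches_tracked buffer alongside running_mean and running_var.

```python
import torch.nn as nn

# Hypothetical stand-in for ValueNetwork: Linear -> BatchNorm1d.
net = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))

param_names = {name for name, _ in net.named_parameters()}
buffer_names = {name for name, _ in net.named_buffers()}
state_keys = set(net.state_dict().keys())

print(sorted(param_names))   # weights and biases only
print(sorted(buffer_names))  # BatchNorm running statistics and batch counter
# state_dict() holds both the parameters and the (persistent) buffers.
print(state_keys == param_names | buffer_names)
```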

My questions are:

  1. After performing the smooth update of the parameters, what is the correct way to update the batch norm statistics, i.e. running_mean and running_var?
  2. Could you provide a code snippet so I can understand how to do this?

Thank you!

P.S. @ptrblck, I saw a similar topic you replied to in the past, but I can no longer find the thread.

You can skip overwriting these buffers by passing the strict=False argument to load_state_dict.
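A minimal sketch of that suggestion, assuming a small stand-in network: the blended parameters are loaded with strict=False, and the returned object lists the keys that were skipped. Note that the target's BatchNorm buffers are then simply left at whatever values they already had; they are never updated.

```python
import copy

import torch
import torch.nn as nn

# Hypothetical stand-in networks with the same architecture.
value = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
target_value = copy.deepcopy(value)
tau = 0.005

with torch.no_grad():
    # Blend parameters only, as in the original method.
    target_params = dict(target_value.named_parameters())
    blended = {
        name: tau * p + (1 - tau) * target_params[name]
        for name, p in value.named_parameters()
    }

# With strict=False, missing keys (the BatchNorm buffers) are reported instead of raising.
result = target_value.load_state_dict(blended, strict=False)
print(result.missing_keys)     # running_mean, running_var, num_batches_tracked
print(result.unexpected_keys)  # nothing unexpected
```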

Disregard that: the actual problem is that you're not including the buffers in the state dict (try Module.buffers(), or Module.named_buffers() to get their names). I'm not sure about fixing the statistics; I think the correct values won't be available without re-evaluating the network on data, but the error may not be significant.
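One way to include the buffers is to skip load_state_dict entirely and update the target in place. The helper below (soft_update is a hypothetical name, not a PyTorch API) blends parameters with Tensor.lerp_ and, by default, copies buffers from the online network, since averaging running statistics is not obviously correct. It assumes both networks have identical architecture, so that parameters() and buffers() line up by position.

```python
import copy

import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float,
                average_buffers: bool = False) -> None:
    """In place: target <- tau * source + (1 - tau) * target for every parameter.

    Buffers (BatchNorm running_mean / running_var / num_batches_tracked) are
    copied from source by default; with average_buffers=True, floating-point
    buffers are blended with the same tau instead.
    Assumes both modules have identical architecture (same registration order).
    """
    for t_p, s_p in zip(target.parameters(), source.parameters()):
        t_p.lerp_(s_p, tau)  # t + tau * (s - t) == tau * s + (1 - tau) * t
    for t_b, s_b in zip(target.buffers(), source.buffers()):
        if average_buffers and t_b.is_floating_point():
            t_b.lerp_(s_b, tau)
        else:
            t_b.copy_(s_b)  # integer buffers such as num_batches_tracked are always copied


# Usage on hypothetical stand-in networks.
value = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
target_value = copy.deepcopy(value)

# Simulate training of the online network: change a weight and the BN statistics.
with torch.no_grad():
    value[0].weight.add_(1.0)
value.train()
value(torch.randn(8, 3))  # a forward pass in train mode updates the running stats

old_target_w = target_value[0].weight.detach().clone()
soft_update(target_value, value, tau=0.5)
```

Inside the original method this would become soft_update(self.target_value, self.value, tau). Passing average_buffers=True smooths running_mean and running_var with the same tau instead of copying them; which choice works better is an empirical question.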

Hi @ssha,
what you want to do is really close to stochastic weight averaging (SWA).

torch.optim.swa_utils

You can check the source to see how they implement it.
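A minimal sketch of the same soft update expressed with AveragedModel from torch.optim.swa_utils, assuming a custom avg_fn that implements the exponential moving average from the question (ema_avg, tau and the stand-in network are illustrative). AveragedModel keeps a deep copy of the network in its .module attribute. Note that the first update_parameters call only copies the weights; averaging starts from the second call.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

tau = 0.005


def ema_avg(averaged_param, current_param, num_averaged):
    # Same soft update as in the question: tau * new + (1 - tau) * old.
    return (1 - tau) * averaged_param + tau * current_param


# Hypothetical stand-in for the value network.
value = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))
target_value = AveragedModel(value, avg_fn=ema_avg)  # deep copy lives in target_value.module

# Call after each optimizer step on `value`.
# The first call only copies (nothing averaged yet); later calls apply ema_avg.
target_value.update_parameters(value)
```

As far as I can tell, in recent PyTorch releases update_parameters with the default use_buffers=False also copies the source network's buffers into the averaged copy, while use_buffers=True averages them as well; check the documentation for your version.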
For updating the running means, see the update_bn function described in the blog post.
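A minimal sketch of recomputing the statistics with update_bn, assuming the model takes plain input tensors. update_bn resets each BatchNorm layer's running statistics and re-estimates them in a single forward pass over the loader; batches that are lists or tuples are reduced to their first element. Any iterable of batches works, so in an RL setting a few batches of states sampled from a replay buffer could serve as the loader (an assumption, not something from this thread).

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import update_bn

# Hypothetical stand-in for the target network.
target_value = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4))

# Hypothetical stand-in for a data loader: a plain list of input batches.
loader = [torch.randn(16, 3) for _ in range(5)]

# Reset the BatchNorm running stats, then re-estimate them with one pass over `loader`.
update_bn(loader, target_value)
```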