I have two networks, target_value and value.
I am trying to update the parameters of the target_value network with a soft update: the new target_value parameters should be a smoothed combination of the current target_value and value parameters, weighted by tau. The following code works well as long as my networks do not contain a batch normalisation layer.
    def update_network_parameters(self, tau=None):
        if tau is None:
            tau = self.tau

        target_value_params = self.target_value.named_parameters()
        value_params = self.value.named_parameters()

        target_value_state_dict = dict(target_value_params)
        value_state_dict = dict(value_params)

        for name in value_state_dict:
            value_state_dict[name] = tau*value_state_dict[name].clone() + \
                    (1-tau)*target_value_state_dict[name].clone()

        self.target_value.load_state_dict(value_state_dict)
However, once I include the batch normalisation layers, I get the following error:
RuntimeError: Error(s) in loading state_dict for ValueNetwork: Missing key(s) in state_dict: "feature_extractor.extractors.price_features.co1.bn1.running_mean", "feature_extractor.extractors.price_features.co1.bn1.running_var", "feature_extractor.extractors.price_features.co2.bn1.running_mean", "feature_extractor.extractors.price_features.co2.bn1.running_var"
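I understand that the error occurs because .named_parameters() only returns the learnable parameters, so the batch norm statistics (running_mean and running_var) are never included in the dict I pass to load_state_dict().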
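To illustrate the distinction, here is a minimal sketch using a small hypothetical nn.Sequential model (not the ValueNetwork from the error above): named_parameters() yields only the learnable weights, the BatchNorm running statistics are registered as buffers (named_buffers()), and state_dict() contains both.

```python
import torch.nn as nn

# Hypothetical toy model with a BatchNorm layer, used only for illustration.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

param_names = {name for name, _ in model.named_parameters()}
buffer_names = {name for name, _ in model.named_buffers()}
state_names = set(model.state_dict().keys())

print(sorted(param_names))   # ['0.bias', '0.weight', '1.bias', '1.weight']
print(sorted(buffer_names))  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']

# state_dict() is the union of parameters and buffers, which is why
# load_state_dict() reports the buffers as missing keys.
assert state_names == param_names | buffer_names
```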
My questions are:
- After performing the smooth update of the parameters, what is the correct way to update the batch norm statistics, i.e. running_mean and running_var?
- Could you provide a code snippet so I can understand how to do this?
P.S. @ptrblck I saw a similar topic you replied to in the past, but I can no longer find the thread.
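One possible approach, sketched under assumptions: soft-update the parameters in place, then handle the buffers separately. Here the floating-point buffers (running_mean, running_var) are averaged with the same tau, and the integer num_batches_tracked counter is copied. The helper name soft_update and the toy make_net model are hypothetical stand-ins for update_network_parameters and ValueNetwork. Whether to average the running statistics or simply copy them from value is a design choice, not something PyTorch prescribes.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """Hypothetical helper: target = tau * source + (1 - tau) * target, in place."""
    # Learnable parameters (linear weights/biases, BN affine weight/bias).
    s_params = dict(source.named_parameters())
    for name, t_param in target.named_parameters():
        t_param.mul_(1.0 - tau).add_(s_params[name], alpha=tau)

    # Buffers (BN running_mean, running_var, num_batches_tracked).
    s_bufs = dict(source.named_buffers())
    for name, t_buf in target.named_buffers():
        if t_buf.is_floating_point():
            # Assumption: average the running statistics like the weights.
            t_buf.mul_(1.0 - tau).add_(s_bufs[name], alpha=tau)
        else:
            # Integer counters cannot be averaged, so copy them.
            t_buf.copy_(s_bufs[name])


def make_net() -> nn.Module:
    # Hypothetical stand-in for ValueNetwork.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))


torch.manual_seed(0)
value = make_net()
target_value = make_net()
target_value.load_state_dict(value.state_dict())  # start from identical weights

# A training-mode forward pass updates value's running statistics.
value.train()
value(torch.randn(16, 4))

soft_update(target_value, value, tau=0.005)
```

To hard-copy the statistics instead, replace the floating-point branch with t_buf.copy_(s_bufs[name]). Both networks need the same architecture so that their parameter and buffer names match.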