Hello,
I have two networks, target_value and value.
I am trying to update the parameters of the target_value network so that they become a smoothed (soft) blend of the current target_value and value parameters, i.e. tau*value + (1-tau)*target_value. I have the following code, which works well as long as my networks do not contain a batch normalisation layer.
def update_network_parameters(self, tau=None):
    if tau is None:
        tau = self.tau
    target_value_params = self.target_value.named_parameters()
    value_params = self.value.named_parameters()
    target_value_state_dict = dict(target_value_params)
    value_state_dict = dict(value_params)
    for name in value_state_dict:
        value_state_dict[name] = tau*value_state_dict[name].clone() + \
            (1-tau)*target_value_state_dict[name].clone()
    self.target_value.load_state_dict(value_state_dict)
However, once I include batch normalisation layers, I get the following error:
RuntimeError: Error(s) in loading state_dict for ValueNetwork:
Missing key(s) in state_dict: "feature_extractor.extractors.price_features.co1.bn1.running_mean", "feature_extractor.extractors.price_features.co1.bn1.running_var", "feature_extractor.extractors.price_features.co2.bn1.running_mean", "feature_extractor.extractors.price_features.co2.bn1.running_var"
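I understand that the error occurs because .named_parameters() only returns the learnable parameters. The batch norm running statistics (running_mean and running_var) are stored as buffers, so they are missing from the dict I pass to load_state_dict.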
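To illustrate the point, here is a minimal sketch on a hypothetical toy network (a stand-in, not my actual ValueNetwork). It compares the names returned by named_parameters(), named_buffers() and state_dict(); the layer sizes are arbitrary assumptions.

```python
import torch.nn as nn

# Hypothetical toy network standing in for ValueNetwork (sizes are arbitrary).
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

# Learnable tensors only: weights and biases.
param_names = {name for name, _ in net.named_parameters()}
# Non-learnable state, e.g. the BatchNorm running statistics.
buffer_names = {name for name, _ in net.named_buffers()}
# Everything that load_state_dict expects by default.
state_names = set(net.state_dict().keys())

# Keys that a dict built from named_parameters() would lack.
missing = state_names - param_names
print(sorted(missing))
```

In this toy case the keys missing from the parameter dict are exactly the BatchNorm buffers, which seems to match the keys listed in the error above.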
My questions are:
- After performing the smooth update of the parameters, what is the correct way to update the batch norm statistics, i.e. running_mean and running_var?
- Are you able to provide a code snippet so I can understand how to do the above?
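For reference, one possible shape for such an update is sketched below. This is only a sketch under assumptions, not a confirmed answer: soft_update and make_net are hypothetical names, the toy architecture is made up, and copying the buffers directly from the source network (rather than blending them with tau) is one design choice among several.

```python
import torch
import torch.nn as nn


def soft_update(target: nn.Module, source: nn.Module, tau: float) -> None:
    """In place: target <- tau * source + (1 - tau) * target for parameters."""
    with torch.no_grad():
        # Blend the learnable parameters (assumes both nets share one architecture).
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(s_param, alpha=tau)
        # Assumption: buffers (running_mean, running_var, num_batches_tracked)
        # are copied as-is instead of being blended.
        for t_buf, s_buf in zip(target.buffers(), source.buffers()):
            t_buf.copy_(s_buf)


def make_net() -> nn.Module:
    # Hypothetical toy network with a BatchNorm layer.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 1))


value = make_net()
target_value = make_net()
value(torch.randn(16, 4))  # a training-mode forward pass updates value's running stats
soft_update(target_value, value, tau=0.005)
```

Because the tensors are modified in place, no call to load_state_dict is needed in this sketch, so the missing-key check is never triggered. Whether copying or blending the running statistics is preferable likely depends on the use case.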
Thank you!
P.S. @ptrblck, I saw a similar topic you replied to in the past, but I can no longer find the thread.