I am trying to use a custom update scheme for the moving averages (running statistics) of the batch normalization layers in my model. Does anyone have an idea how to achieve this?
Here is a snippet of my network:

for subnetwork in ['h_net', 'o_net', 'c_net']:
    setattr(self, subnetwork, nn.Sequential(nn.Linear(input_dim, 80),
                                            nn.ReLU(),
                                            nn.BatchNorm1d(80),
                                            nn.Linear(80, 40),
                                            nn.ReLU(),
                                            nn.BatchNorm1d(40),
                                            nn.Linear(40, 1),
                                            nn.ReLU()))

In the end I sum up the results of the different subnetworks.
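For context, a minimal self-contained sketch of how such a model might look, assuming `input_dim` and a hypothetical module name `SubnetSum` with a forward pass that sums the three subnetworks:

```python
import torch
import torch.nn as nn

class SubnetSum(nn.Module):
    # Hypothetical wrapper module mirroring the snippet above.
    def __init__(self, input_dim):
        super().__init__()
        for subnetwork in ['h_net', 'o_net', 'c_net']:
            setattr(self, subnetwork, nn.Sequential(
                nn.Linear(input_dim, 80),
                nn.ReLU(),
                nn.BatchNorm1d(80),
                nn.Linear(80, 40),
                nn.ReLU(),
                nn.BatchNorm1d(40),
                nn.Linear(40, 1),
                nn.ReLU()))

    def forward(self, x):
        # Sum the outputs of the three subnetworks.
        return self.h_net(x) + self.o_net(x) + self.c_net(x)

model = SubnetSum(input_dim=10)
out = model(torch.randn(4, 10))
print(out.shape)  # torch.Size([4, 1])
```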

Could you explain your use case a bit more, as I'm not sure I understand it completely.
Currently you have three subnetworks using BatchNorm layers.
Would you like to change the update of the running averages, i.e. mean and std?
Or would you like to average the running averages of your three subnetworks somehow?

I want to change the update of the running averages according to a * b**(x/c), where a, b, c are chosen and x is the learning epoch. So at the end of every epoch I want to set the running averages.
Here is a quote from the paper I am trying to implement:
“If Batch Normalization is used, the parameter for moving average will be a * b**(x/c)”
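In PyTorch this maps naturally onto the `momentum` attribute of the BatchNorm layers: running stats are updated as running = (1 - momentum) * running + momentum * batch_stat, so momentum is the weight given to the new batch statistic. A minimal sketch of the per-epoch schedule, with hypothetical a, b, c values:

```python
import torch.nn as nn

# a, b, c are hypothetical values here; in the paper they are user-chosen constants.
a, b, c = 0.5, 0.9, 10.0

bn = nn.BatchNorm1d(80)
for epoch in range(5):
    # At the start of each epoch, set the BatchNorm momentum to a * b**(epoch / c).
    bn.momentum = a * b ** (epoch / c)
    # ... train for one epoch ...
```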

@Stefaanhess, depending on your network structure, your code might not work. For example, if the children of layer have children with batch norm children, you will never touch them.

I use this to recursively search for all batch norm children to update:

def recursive_batch_norm_update(child, epoch, a, b, c):
    # Set the momentum of every BatchNorm1d layer, however deeply nested.
    if isinstance(child, torch.nn.BatchNorm1d):
        child.momentum = a * b ** (epoch / c)
        return
    for grandchild in child.children():
        recursive_batch_norm_update(grandchild, epoch, a, b, c)
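As an aside, PyTorch's built-in `nn.Module.apply` already visits every submodule recursively, so the same update can be done without a hand-rolled recursion. A sketch with hypothetical a, b, c values:

```python
import torch.nn as nn

# Hypothetical schedule constants for illustration.
a, b, c = 0.5, 0.9, 10.0
epoch = 4

def set_bn_momentum(m):
    # nn.Module.apply calls this on every submodule, however deeply nested.
    if isinstance(m, nn.BatchNorm1d):
        m.momentum = a * b ** (epoch / c)

# A BatchNorm layer nested two levels deep still gets updated.
model = nn.Sequential(nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)))
model.apply(set_bn_momentum)
```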