Update scheme for BatchNorm momentum

I am trying to use an update scheme for the moving averages of the batch normalization layers in my model. Does anyone have an idea how to achieve this?
Here is a snippet of my network:

for subnetwork in ['h_net', 'o_net', 'c_net']:
    setattr(self, subnetwork, nn.Sequential(nn.Linear(input_dim, 80),
                                            nn.ReLU(),
                                            nn.BatchNorm1d(80),
                                            nn.Linear(80, 40),
                                            nn.ReLU(),
                                            nn.BatchNorm1d(40),
                                            nn.Linear(40, 1),
                                            nn.ReLU()))

In the end I sum up the results of the different subnetworks.

Could you explain your use case a bit more, as I'm not sure I understand it completely.
Currently you have three subnetworks using BatchNorm layers.
Would you like to change the update of the running averages, i.e. mean and std?
Or would you like to average the running averages of your three subnetworks somehow?

I want to change the update of the running averages according to a * b**(x/c), where a, b, c are chosen constants and x is the training epoch. So at the end of every epoch I want to set the momentum used for the running averages.
Here is a quote from the paper that I am trying to implement:
“If Batch Normalization is used, the parameter for moving average will be a * b**(x/c)”
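To make the schedule concrete, hypothetical values a = 0.9, b = 0.5, c = 10 would halve the parameter every 10 epochs:

a, b, c = 0.9, 0.5, 10   # hypothetical values; the paper leaves them to be chosen
for x in (0, 10, 20, 30):
    print(x, a * b**(x / c))   # 0.9, 0.45, 0.225, 0.1125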

Your formula a*b**(epoch/c) would therefore be the momentum term?
I.e. your running stats would be updated for each batch as:

momentum = a * b**(epoch/c)
x_new_running = (1 - momentum) * x_running + momentum * x_new_observed
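You can check this behavior directly; the values below are only illustrative:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3, momentum=0.1)   # running_mean starts at zeros
x = torch.randn(8, 3)
bn.train()
bn(x)   # one forward pass in training mode updates the running stats

# expected: (1 - momentum) * 0 + momentum * batch_mean
print(torch.allclose(bn.running_mean, 0.1 * x.mean(dim=0)))   # True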

Thanks for your reply!
I tried to update the momentum with the following code:

for subnetwork in self.children():
    for layer in subnetwork.children():
        if isinstance(layer, torch.nn.BatchNorm1d):
            layer.momentum = a * b**(self.epoch / c)

which seems to work.
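In case it's useful to others, a rough sketch of where that update sits in a training loop; model, num_epochs and train_one_epoch are placeholders:

for epoch in range(num_epochs):
    # refresh the momentum schedule once per epoch
    for subnetwork in model.children():
        for layer in subnetwork.children():
            if isinstance(layer, torch.nn.BatchNorm1d):
                layer.momentum = a * b**(epoch / c)
    train_one_epoch(model)   # placeholder for the usual epoch of training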

@Stefaanhess, depending on your network structure, your code might not work. For example, if one of your subnetworks contains nested modules that themselves hold batch norm layers, your two loops will never reach them.
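If a subnetwork looked like this (a made-up structure just for illustration), the inner BatchNorm1d would sit one level deeper than the second loop reaches:

block = nn.Sequential(nn.Linear(10, 10),
                      nn.Sequential(nn.BatchNorm1d(10),   # nested one level deeper
                                    nn.ReLU()))
# block.children() yields only the Linear and the inner Sequential,
# so iterating a fixed number of levels never sees this BatchNorm1d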

I use this to recursively search for all batch norm children to update:

def recursive_batch_norm_update(module, momentum):
    # set the momentum directly on batch norm layers
    if isinstance(module, torch.nn.BatchNorm1d):
        module.momentum = momentum
        return
    # otherwise recurse into the immediate children
    for child in module.children():
        recursive_batch_norm_update(child, momentum)

then just call

recursive_batch_norm_update(network, a * b**(epoch / c))

I think model.modules() will also work. It searches submodules recursively.
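For example, something like this should hit every batch norm layer regardless of nesting (a, b, c, epoch assumed defined as above):

for m in model.modules():
    if isinstance(m, torch.nn.BatchNorm1d):
        m.momentum = a * b**(epoch / c)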
