Separate BatchNorm parameters in multi-group parameter lists

I want to create an API method that removes weight decay from BatchNorm parameters, given multiple user-defined parameter groups that each have different optimizer arguments.
My idea is to iterate over `module.modules()` to find the BN layers and record the identities (`id()`) of their parameters. I then iterate over the parameter groups and check each parameter's identity against that record. This produces 2N parameter groups: for each of the N original groups, one without weight decay and one with it.
It seems to work well. Can you think of cases where this logic might break, or a simpler solution?

# Example usage.
# NOTE(review): `module` is not defined in this snippet; it is assumed to be an
# nn.Module with `backbone` and `head` submodules — confirm.
# split_weight_decay_params is defined further down, so if this runs
# top-to-bottom as a script, the call below raises NameError. Define the
# function first.
weight_decay = 1e-4
lr = 0.01
# Two user-defined groups: the head uses a 10x larger learning rate.
# NOTE(review): if parameters() returns a one-shot generator (as in torch),
# these dicts can only be consumed once.
param_groups = [{'params': module.backbone.parameters(), 'lr': lr},
               {'params': module.head.parameters(), 'lr': lr * 10}]
# Expands the 2 input groups into 4: a (no-decay, decay) pair per input group.
optimizer_param_groups = split_weight_decay_params(module, param_groups, weight_decay)

def split_weight_decay_params(module: nn.Module, param_groups, weight_decay, no_decay_types=None):
    """Split each optimizer param group into a no-decay half and a decay half.

    Parameters owned directly by "no-decay" layers (every BatchNorm variant by
    default) go into a group with ``weight_decay=0.0``. All other parameters
    go into a sibling group that keeps weight decay. Every other per-group
    option (``lr``, ``momentum``, ...) is copied to both halves.

    Args:
        module: Model whose submodules are scanned for no-decay layers.
        param_groups: Iterable of param-group dicts in the format accepted by
            ``torch.optim.Optimizer``. Each dict has a ``"params"`` entry that
            is an iterable of tensors or a single tensor. A plain iterable of
            tensors is also accepted and treated as one group, as
            ``torch.optim`` does.
        weight_decay: Weight decay for the decay half of a group that does not
            set its own ``"weight_decay"``. A group-level value takes
            precedence.
        no_decay_types: Tuple of module classes whose own parameters get no
            weight decay. Defaults to ``nn.modules.batchnorm._BatchNorm``,
            which covers BatchNorm1d/2d/3d, their Lazy variants and
            SyncBatchNorm.

    Returns:
        A list with two dicts per input group, in input order: first the
        no-decay group (``weight_decay`` forced to 0.0), then the decay group.
        Either half may have an empty ``"params"`` list.
    """
    import torch  # local import: only `nn` is known to be in scope at file level

    if no_decay_types is None:
        no_decay_types = (nn.modules.batchnorm._BatchNorm,)

    # A set gives O(1) membership tests. parameters(recurse=False) yields only
    # the layer's own registered parameters, so affine=False layers (whose
    # weight/bias are None) contribute nothing instead of id(None).
    no_decay_ids = {
        id(p)
        for m in module.modules()
        if isinstance(m, no_decay_types)
        for p in m.parameters(recurse=False)
    }

    param_groups = list(param_groups)
    if param_groups and not isinstance(param_groups[0], dict):
        # Mirror torch.optim: a bare iterable of tensors is a single group.
        param_groups = [{"params": param_groups}]

    optimizer_param_groups = []
    for param_group in param_groups:
        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            # Iterating a single tensor would yield its rows, not the tensor.
            params = [params]

        no_decay_params = []
        decay_params = []
        for param in params:
            if id(param) in no_decay_ids:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        extra_optim_params = {key: value for key, value in param_group.items() if key != "params"}
        # weight_decay is set AFTER the copied options so a group-level
        # "weight_decay" can never re-enable decay on the no-decay half.
        optimizer_param_groups.append(
            {**extra_optim_params, "params": no_decay_params, "weight_decay": 0.0}
        )
        # The decay half honours a group-level weight_decay, falling back to
        # the global one (same precedence as before).
        optimizer_param_groups.append(
            {
                **extra_optim_params,
                "params": decay_params,
                "weight_decay": extra_optim_params.get("weight_decay", weight_decay),
            }
        )

    return optimizer_param_groups
1 Like