Weight decay in the optimizers is a bad idea (especially with BatchNorm)

Correct me if I’m wrong, but there is no reason the beta and gamma parameters in BatchNorm should ever be subject to weight decay, i.e. L2 regularization, that pulls them toward 0. In fact, it seems like a very bad idea to pull them toward 0. I know you can use per-parameter options to get around the optimizer’s default behavior, but it seems like bad default behavior, especially given how commonly BatchNorm is used. It seems better to specify regularization only where you want it, rather than have the default apply it everywhere and thus even in places it shouldn’t. I believe PyTorch could actually use a cleaner regularization framework, perhaps similar to what we implemented in Keras. Thoughts?
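
For concreteness, the per-parameter-options workaround mentioned above looks roughly like the sketch below (the toy model and hyper-parameters are only illustrative; the mechanism is the optimizer’s parameter groups with a per-group weight_decay):

import torch
import torch.nn as nn

# Toy model just for illustration
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

bn_params, other_params = [], []
for m in model.modules():
    # all BN variants inherit from _BatchNorm
    target = bn_params if isinstance(m, nn.modules.batchnorm._BatchNorm) else other_params
    target.extend(p for p in m.parameters(recurse=False) if p.requires_grad)

# two parameter groups: decay everything except the BN gamma/beta
optimizer = torch.optim.SGD(
    [{'params': other_params, 'weight_decay': 1e-4},
     {'params': bn_params, 'weight_decay': 0.0}],
    lr=0.1, momentum=0.9)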


It has been a year, but no feedback on this?
I absolutely agree with @Michael_Oliver. For example, in the official ImageNet training example https://github.com/pytorch/examples/tree/master/imagenet, it looks to me like the batch norm weight and bias are included in the regularization loss (weight decay).
I think the correct way to implement it should be:

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum)
                            # weight_decay=args.weight_decay removed here

cls_loss = criterion(output, target)
reg_loss = 0
for name, param in model.named_parameters():
    if 'bn' not in name:
        reg_loss += param.pow(2).sum()  # squared L2 norm, which is what weight decay penalizes
loss = cls_loss + 0.5 * args.weight_decay * reg_loss  # add weight decay manually

Please confirm or help with more elegant solutions. Thank you.
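
As a quick sanity check on the factor of 0.5 above (a sketch, not from the original post): with weight_decay=wd the optimizer adds wd * w to each gradient, which is exactly the gradient of a 0.5 * wd * ||w||^2 term added to the loss.

import torch

w = torch.randn(5, requires_grad=True)
wd = 1e-2

loss = 0.5 * wd * w.pow(2).sum()  # explicit L2 penalty on the loss side
loss.backward()
print(torch.allclose(w.grad, wd * w.detach()))  # True: same as grad.add(p, alpha=wd)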


I usually create a fn like add_weight_decay below. In its current form it will add all batch norm parameters and bias parameters to the no_decay list. I use this instead of looking for ‘bn’ strings in the name because that isn’t always consistent from model to model. It’s usually suggested that bias params should not be decayed either, so this does the job for me. You can still use the name if it works for your case, though; separating them out any other way would be a pain.

def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name in skip_list:
            # 1-D params: BN weights/biases and bias terms, plus anything explicitly skipped
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]

When I create the optimizer, I put this block in front (usually all of this is wrapped in an optim creation factory that also picks which optimizer to create from config or cmd args…):

    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=weight_decay, nesterov=args.nesterov)
    ...
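
For a self-contained picture, a usage sketch without the args plumbing might look like this (the toy model and hyper-parameters are made up for illustration):

import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
param_groups = add_weight_decay(model, weight_decay=1e-4)

# weight_decay is set per group above, so the optimizer-level value stays 0
optimizer = optim.SGD(param_groups, lr=0.1, momentum=0.9, weight_decay=0.)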

Nice, thank you.
But I am just curious why even the official training example still ignores this fundamental issue.

Will the weight decay on bn layers affect the performance of the model? Why?

Well, weight decay basically pulls the norm of the parameters toward 0. In batch norm the output is y = gamma * x_hat + beta, where x_hat = (x - mean) / std is the normalized input and gamma and beta are the learnable scale and shift. You don’t want gamma and beta to be pulled toward 0: if gamma goes to 0, the output collapses to the constant beta, and BN becomes meaningless and erroneous.
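
A tiny numeric illustration of that point (just a sketch, not from the post): with gamma forced to 0, the BN output no longer depends on the input at all.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
with torch.no_grad():
    bn.weight.zero_()   # gamma -> 0, the direction weight decay pushes
    bn.bias.fill_(0.5)  # beta

x = torch.randn(8, 4)
print(bn(x))  # every entry is 0.5, regardless of x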

This finding is also backed by this publication:

While I very much appreciate the code snippet provided by @rwightman, I too would like to see at least an option in PyTorch to exclude normalization layers from weight decay.


Please refer to the optimizer code in PyTorch. In detail, after backward(), the weight is added to its own gradient, scaled by the decay factor (L2 weight decay). We could also use the solution above directly to avoid applying weight decay to BN. However, I have another, arguably more elegant, option in the function below:


def apply_weight_decay(*modules, weight_decay_factor=0., wo_bn=True):
    '''
    https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/5
    Apply weight decay to a PyTorch model, optionally skipping BN layers.
    This mirrors what the optimizer does internally:
        if group['weight_decay'] != 0:
            grad = grad.add(p, alpha=group['weight_decay'])
    where p is the parameter.
    :param modules: one or more nn.Module instances
    :param weight_decay_factor: weight decay coefficient
    :param wo_bn: if True, skip all batch norm layers
    '''
    for module in modules:
        for m in module.modules():
            if hasattr(m, 'weight') and m.weight is not None and m.weight.grad is not None:
                if wo_bn and isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                    continue
                # add wd * w to the gradient, just like the optimizer would
                with torch.no_grad():
                    m.weight.grad.add_(m.weight, alpha=weight_decay_factor)

Please note that it should be applied after loss.backward() (and before optimizer.step()). Also, as far as I know, all BN layers in PyTorch inherit from the _BatchNorm class.
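
A usage sketch (the toy model, loss, and hyper-parameters are illustrative; only the ordering of the calls matters), with weight_decay=0 in the optimizer so the decay is not applied twice:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=0.)

x, target = torch.randn(16, 10), torch.randint(0, 2, (16,))
optimizer.zero_grad()
loss = criterion(model(x), target)
loss.backward()
apply_weight_decay(model, weight_decay_factor=1e-4, wo_bn=True)  # after backward, before step
optimizer.step()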