Weight decay in the optimizers is a bad idea (especially with BatchNorm)

Correct me if I’m wrong, but there is no reason the beta and gamma parameters in BatchNorm should ever be subject to weight decay, i.e. L2 regularization that pulls them toward 0. In fact it seems like a very bad idea to pull them toward 0. I know you can use per-parameter options to get around the optimizer’s default behavior, but it seems like bad default behavior, especially with how commonly BatchNorm is used. It seems better to only specify regularization where you want it rather than have the default apply it everywhere, and thus even in places it shouldn’t. I believe PyTorch could actually use a cleaner regularization framework, perhaps similar to what we implemented in Keras. Thoughts?
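
For context, a minimal sketch of the per-parameter-options workaround mentioned above; the toy model and the "1-d parameter" heuristic are only illustrative, not an official recipe:

import torch
import torch.nn as nn

# illustrative model containing a BatchNorm layer
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

# BN gammas/betas (and plain biases) are 1-d tensors; give them weight_decay=0
decay, no_decay = [], []
for param in model.parameters():
    if not param.requires_grad:
        continue
    (no_decay if param.dim() == 1 else decay).append(param)

optimizer = torch.optim.SGD(
    [{'params': decay, 'weight_decay': 1e-4},
     {'params': no_decay, 'weight_decay': 0.0}],
    lr=0.1, momentum=0.9)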


It has been a year, but no feedback on it?
I absolutely agree with @Michael_Oliver. For example, in the official ImageNet training example https://github.com/pytorch/examples/tree/master/imagenet, it seems to me that the batch norm weight and bias are included in the regularization loss.
I think the correct way to implement it should be:

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum)
                            # weight_decay=args.weight_decay)  # remove weight decay here

cls_loss = criterion(output, target)
reg_loss = 0
for name, param in model.named_parameters():
    if 'bn' not in name:  # skip parameters whose name marks them as batch norm
        reg_loss += param.pow(2).sum()  # squared L2 norm, which is what weight decay penalizes
loss = cls_loss + 0.5 * args.weight_decay * reg_loss  # manually add weight decay

Please confirm or help with more elegant solutions. Thank you.


I usually create a fn like add_weight_decay below. In its current form it will add all batch norm parameters and bias parameters to the no_decay list. I use this instead of looking for ‘bn’ strings in the name because that isn’t always consistent from model to model. It’s usually suggested that bias params should also not be decayed, so this does the job for me. You can still match on the name if it works for your case, though; separating the two otherwise would be a pain.

def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights
        if len(param.shape) == 1 or name in skip_list:
            no_decay.append(param)  # biases and norm-layer weights are 1-d tensors
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]

When I create the optimizer, I put this block in front (usually all this is wrapped in an optim creation factory that also picks the optimizer to create from config or cmd args)…

weight_decay = args.weight_decay
if weight_decay and filter_bias_and_bn:
    parameters = add_weight_decay(model, weight_decay)
    weight_decay = 0.
else:
    parameters = model.parameters()

if args.opt.lower() == 'sgd':
    optimizer = optim.SGD(
        parameters, lr=args.lr,
        momentum=args.momentum, weight_decay=weight_decay,
        nesterov=args.nesterov)
...
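
A quick sanity check of the grouping, if useful (torchvision’s resnet18 is only used here as an example model):

import torchvision

model = torchvision.models.resnet18()
groups = add_weight_decay(model, weight_decay=1e-4)
# groups[0]: BN weights/biases and all other 1-d params (weight_decay=0)
# groups[1]: conv/linear weights (decayed)
print(len(groups[0]['params']), len(groups[1]['params']))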

Nice. Thank you.
But I am just curious why even the official training code still overlooks this fundamental issue.

Will the weight decay on bn layers affect the performance of the model? Why?

Well, weight decay basically pulls the norm of parameters toward 0. In batch norm, e.g. x_hat = (x - beta)/gamma, you don’t want beta and gamma to go to 0. Otherwise BN is meaningless and erroneous.

This finding is also backed by this publication:

While I very much appreciate the code snippet provided by @rwightman , I too would like to see at least an option in PyTorch to exclude normalization layers from weight decay.


Please refer to the optimizer code in PyTorch. In detail, after backward, the weight (scaled by the decay factor) is added to the gradient of the weight (this is L2 weight decay). We could also directly use the above solution to avoid applying weight decay to BN. However, I have another, more elegant method, the function below:


import torch


def apply_weight_decay(*modules, weight_decay_factor=0., wo_bn=True):
    '''
    https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/5
    Apply weight decay to a PyTorch model while skipping BN layers.
    This mimics what the optimizer does internally:
        if group['weight_decay'] != 0:
            grad = grad.add(p, alpha=group['weight_decay'])
    where p is the parameter.
    :param modules: one or more nn.Module instances
    :param weight_decay_factor: the decay coefficient
    :param wo_bn: if True, skip all batch norm layers
    '''
    for module in modules:
        for m in module.modules():
            if hasattr(m, 'weight') and m.weight is not None:
                if wo_bn and isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                    continue
                if m.weight.grad is None:  # parameter did not receive a gradient
                    continue
                # add weight_decay_factor * weight to the gradient, outside autograd
                m.weight.grad.add_(m.weight.detach(), alpha=weight_decay_factor)

Please note that it should be applied after loss.backward() and before optimizer.step(). Besides, as far as I know, all BN layers in PyTorch inherit from the _BatchNorm class.
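
A usage sketch of the call order (the training-loop names here are placeholders):

# inside the training loop
optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
apply_weight_decay(model, weight_decay_factor=1e-4)  # after backward, before step
optimizer.step()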

I think you may have mistaken beta and gamma for the running mean and running variance.
The correct BN operation is:
x_hat = (x - running_mean) / sqrt(running_var + eps) * gamma + beta
So beta and gamma are actually the bias and weight of the BN layer.
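
If it helps, a quick self-contained check of that formula against nn.BatchNorm2d in eval mode (the values here are only illustrative):

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
bn.eval()  # use the running statistics, as in the formula above
with torch.no_grad():
    bn.weight.copy_(torch.tensor([0.5, 1.5, 2.0]))        # gamma
    bn.bias.copy_(torch.tensor([0.1, -0.2, 0.3]))         # beta
    bn.running_mean.copy_(torch.tensor([0.2, -0.1, 0.4]))
    bn.running_var.copy_(torch.tensor([1.2, 0.8, 2.0]))

x = torch.randn(4, 3, 8, 8)
mean = bn.running_mean.view(1, 3, 1, 1)
var = bn.running_var.view(1, 3, 1, 1)
gamma = bn.weight.view(1, 3, 1, 1)
beta = bn.bias.view(1, 3, 1, 1)
manual = (x - mean) / torch.sqrt(var + bn.eps) * gamma + beta
print(torch.allclose(bn(x), manual, atol=1e-6))  # prints True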

@rwightman, I’m no expert, but as @JIALI_MA suggested, the running averages of the batch norm layer are buffers, not parameters, hence they can’t appear in model.parameters()

import torch
import torch.nn as nn
# With Learnable Parameters
m = nn.BatchNorm2d(5)
print(dict(m.named_parameters()))

output:

{'weight': Parameter containing:
 tensor([1., 1., 1., 1., 1.], requires_grad=True),
 'bias': Parameter containing:
 tensor([0., 0., 0., 0., 0.], requires_grad=True)}

print(dict(m.named_buffers()))

output:

{'running_mean': tensor([0., 0., 0., 0., 0.]),
 'running_var': tensor([1., 1., 1., 1., 1.]),
 'num_batches_tracked': tensor(0)}

Batch norm has learnable affine weight and bias parameters, separate from the running mean and var.


The following paper [1] shows experimental results that “the effect of regularization was concentrated on the BN layer,” so we should be cautious about making weight decay on BN layers off by default.

As evidence, we found that almost all of the regularization effect of weight decay was due to applying it to layers with BN (for which weight decay is meaningless).

It seems that the mechanism of weight decay is not fully understood even in the research field. At least until there is a clear empirical and theoretical basis, such a change to the default behavior should be held off.

[1] Three Mechanisms of Weight Decay Regularization, https://arxiv.org/abs/1810.12281

I would also be interested in a simple option for weight decay not to apply to batchnorm. fast.ai, citing Tencent, reports that removing weight decay for batchnorm was helpful.

The intended logic seems more readable to me when using isinstance(). Here’s an add_weight_decay() variant of @Hzzone’s apply_weight_decay():

import torch.nn as nn


def add_weight_decay(
        model,
        weight_decay=1e-5,
        skip_list=(nn.modules.batchnorm._BatchNorm,
                   nn.modules.instancenorm._InstanceNorm)):
    """Using .modules() with an isinstance() check."""
    decay = []
    no_decay = []
    for module in model.modules():
        # recurse=False so each parameter is visited exactly once,
        # by the module that directly owns it
        params = [p for p in module.parameters(recurse=False) if p.requires_grad]
        if isinstance(module, skip_list):
            no_decay.extend(params)
        else:
            decay.extend(params)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
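
As with @rwightman’s version, the returned groups can be passed straight to the optimizer (the resnet18 model and the hyperparameter values here are only illustrative):

import torch
import torchvision

model = torchvision.models.resnet18()
param_groups = add_weight_decay(model, weight_decay=1e-4)
# each group carries its own weight_decay, so nothing else needs to be set on the optimizer
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)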

However, I found other research stating that weight decay on BN can be beneficial in some network architectures.

Unlike weight decay on weights in e.g. convolutional layers, which typically directly precede normalization layers, weight decay on γ and β can have a regularization effect so long as there is a path in the network between the layer in question and the ultimate output of the network, as if such paths do not pass through another normalization layer, then the weight decay is never “undone” by normalization.

Regularization in the form of weight decay on the normalization parameters γ and β can be applied to any normalization layer, but is only effective in architectures with particular connectivity properties like ResNets and in tasks for which models are already overfitting.

From https://arxiv.org/pdf/1906.03548.pdf