Weight decay in the optimizers is a bad idea (especially with BatchNorm)

Correct me if I’m wrong, but there is no reason the beta and gamma parameters in BatchNorm should ever be subject to weight decay, i.e. L2 regularization that pulls them toward 0. In fact it seems like a very bad idea to pull them toward 0. I know you can use per-parameter options to get around the optimizer’s default behavior, but it seems like bad default behavior, especially given how commonly batch norm is used. It seems better to only specify regularization where you want it rather than have the default apply it everywhere, and thus even in places it shouldn’t be. I believe PyTorch could actually use a cleaner regularization framework, perhaps similar to what we implemented in Keras. Thoughts?
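
For reference, a minimal sketch of the per-parameter-options workaround mentioned above (the toy model and the lr/decay values are made up for illustration):

import torch.nn as nn
import torch.optim as optim

# toy model for illustration only
model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10), nn.ReLU(), nn.Linear(10, 2))

# split parameters: gamma/beta of batch norm layers get no decay, everything else does
norm_params = [p for m in model.modules()
               if isinstance(m, nn.modules.batchnorm._BatchNorm)
               for p in m.parameters()]
other_params = [p for m in model.modules()
                if not isinstance(m, nn.modules.batchnorm._BatchNorm)
                for p in m.parameters(recurse=False)]

optimizer = optim.SGD([
    {'params': other_params, 'weight_decay': 1e-4},
    {'params': norm_params, 'weight_decay': 0.0},  # no decay on gamma/beta
], lr=0.1, momentum=0.9)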


It has been a year, but no feedback on this?
I absolutely agree with @Michael_Oliver. For example, in the official ImageNet training example https://github.com/pytorch/examples/tree/master/imagenet, it seems to me that the batch norm weight and bias are added to the regularization loss.
I think the correct way to implement it should be:

optimizer = torch.optim.SGD(model.parameters(), args.lr,
                            momentum=args.momentum)
                            # ,weight_decay=args.weight_decay)  # remove weight decay here
cls_loss = criterion(output, target)
reg_loss = 0
for name, param in model.named_parameters():
    if 'bn' not in name:
        reg_loss += torch.norm(param)
loss = cls_loss + args.weight_decay * reg_loss  # add weight decay manually

Please confirm or help with more elegant solutions. Thank you.


I usually create a fn like add_weight_decay below. In its current form it will add all batch norm parameters and all bias parameters to the no_decay list. I use this instead of looking for ‘bn’ strings in the name because that isn’t always consistent from model to model. It’s usually suggested that bias params should also not be decayed, so this does the job for me. You can still use the name if it works for your case, though; separating it otherwise would be a pain.

def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or name in skip_list:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]

When I create the optimizer, I put this block in front (usually all of this is wrapped in an optimizer-creation factory that also picks the optimizer to create from config or cmd args)…

    weight_decay = args.weight_decay
    if weight_decay and filter_bias_and_bn:
        parameters = add_weight_decay(model, weight_decay)
        weight_decay = 0.
    else:
        parameters = model.parameters()

    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=weight_decay, nesterov=args.nesterov)
    ...

Nice. Thank you.
But I am just curious why even the official training code still ignores this fundamental error.

Will the weight decay on bn layers affect the performance of the model? Why?


Well, weight decay basically pulls the norm of parameters toward 0. In batch norm, e.g. x_hat = (x - beta) / gamma, you don’t want beta and gamma to go to 0. Otherwise, BN becomes meaningless and erroneous.

This finding is also backed by this publication:

While I very much appreciate the code snippet provided by @rwightman , I too would like to see at least an option in PyTorch to exclude normalization layers from weight decay.


Please refer to the code of the optimizers in PyTorch. In detail, after backward, the weight is added to the grad of the weight (L2 weight decay). We could also directly use the above solution to avoid applying weight decay to BN. However, I have another, arguably more elegant, method in the function below:


import torch


def apply_weight_decay(*modules, weight_decay_factor=0., wo_bn=True):
    '''
    https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/5
    Apply weight decay to a PyTorch model, skipping BN layers if wo_bn is True.
    In PyTorch's optimizers, L2 weight decay is implemented as:
        if group['weight_decay'] != 0:
            grad = grad.add(p, alpha=group['weight_decay'])
    where p is the parameter.
    :param modules: one or more nn.Module instances
    :param weight_decay_factor: weight decay coefficient
    :param wo_bn: if True, skip batch norm layers
    '''
    with torch.no_grad():  # modify .grad outside of autograd tracking
        for module in modules:
            for m in module.modules():
                if hasattr(m, 'weight') and m.weight is not None and m.weight.grad is not None:
                    if wo_bn and isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                        continue
                    # equivalent to the optimizer's grad = grad + weight_decay * p
                    m.weight.grad.add_(m.weight, alpha=weight_decay_factor)

Please note that it should be applied after loss.backward(). Besides, as far as I know, all BN layers in PyTorch inherit from the _BatchNorm class.
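
As a sketch of where this fits in a training step (optimizer, criterion, inputs, and targets below are placeholders; the optimizer itself would be created with weight_decay=0 so decay isn’t applied twice):

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
apply_weight_decay(model, weight_decay_factor=1e-4)  # adds decay to grads, skipping BN
optimizer.step()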

I think you may have confused beta and gamma with the running mean and running variance.
The correct BN operation is:
x_hat = gamma * (x - running_mean) / sqrt(running_var + eps) + beta
So beta and gamma are actually the bias and weight of the BN layer.
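
A quick sanity check of that formula against nn.BatchNorm2d in eval mode (just a sketch, using the default running stats and affine values):

import torch
import torch.nn as nn

m = nn.BatchNorm2d(3).eval()   # eval mode uses running_mean / running_var
x = torch.randn(2, 3, 4, 4)

# broadcast the per-channel stats and affine params over N, H, W
mean = m.running_mean[None, :, None, None]
var = m.running_var[None, :, None, None]
gamma = m.weight[None, :, None, None]
beta = m.bias[None, :, None, None]

manual = (x - mean) / torch.sqrt(var + m.eps) * gamma + beta
print(torch.allclose(m(x), manual, atol=1e-6))  # True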

@rwightman, I’m no expert, but as @JIALI_MA suggested, the running averages of a batch norm layer are buffers, not parameters, hence they won’t appear in model.parameters().

import torch
import torch.nn as nn
# With Learnable Parameters
m = nn.BatchNorm2d(5)
print(dict(m.named_parameters()))

output:

{'weight': Parameter containing:
 tensor([1., 1., 1., 1., 1.], requires_grad=True),
 'bias': Parameter containing:
 tensor([0., 0., 0., 0., 0.], requires_grad=True)}
print(dict(m.named_buffers()))

output:

{'running_mean': tensor([0., 0., 0., 0., 0.]),
 'running_var': tensor([1., 1., 1., 1., 1.]),
 'num_batches_tracked': tensor(0)}

Batch norm has learnable affine weight and bias parameters, separate from the running mean and var.


The following paper [1] presents experimental results showing that “the effect of regularization was concentrated on the BN layer,” so we should be cautious about making the behavior such that weight decay on BN layers is off by default.

As evidence, we found that almost all of the regularization effect of weight decay was due to applying it to layers with BN (for which weight decay is meaningless).

It seems that the mechanism of weight decay is not fully understood even in the research community. At least until there is a clear empirical and theoretical basis, the modification proposed above should be withheld.

[1] Three Mechanisms of Weight Decay Regularization, arXiv:1810.12281

I would also be interested in a simple option for weight decay not to apply to batchnorm. fast.ai, citing Tencent, reports that removing weight decay for batchnorm was helpful.

The intended logic seems more readable to me when using isinstance(). Here’s an add_weight_decay() variant of @Hzzone’s apply_weight_decay():

import torch.nn as nn

def add_weight_decay(
        model,
        weight_decay=1e-5,
        skip_list=(nn.modules.batchnorm._BatchNorm,
                   nn.modules.instancenorm._InstanceNorm)):
    """Using .modules() with an isinstance() check."""
    decay = []
    no_decay = []
    for module in model.modules():
        # recurse=False so each parameter is collected exactly once
        params = [p for p in module.parameters(recurse=False) if p.requires_grad]
        if isinstance(module, skip_list):
            no_decay.extend(params)
        else:
            decay.extend(params)
    return [
        {'params': no_decay, 'weight_decay': 0.},
        {'params': decay, 'weight_decay': weight_decay}]
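
A possible usage sketch (lr and decay values are placeholders); since each group already carries its own weight_decay, nothing is passed to the optimizer for it:

import torch.optim as optim

param_groups = add_weight_decay(model, weight_decay=1e-5)
optimizer = optim.SGD(param_groups, lr=0.1, momentum=0.9)  # no global weight_decay needed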

However, I found other research stating that weight decay on BN can be good in some network architectures.

Unlike weight decay on weights in e.g. convolutional layers, which typically directly precede normalization layers, weight decay on γ and β can have a regularization effect so long as there is a path in the network between the layer in question and the ultimate output of the network, as if such paths do not pass through another normalization layer, then the weight decay is never “undone” by normalization.

Regularization in the form of weight decay on the normalization parameters γ and β can be applied to any normalization layer, but is only effective in architectures with particular connectivity properties like ResNets and in tasks for which models are already overfitting.

in https://arxiv.org/pdf/1906.03548.pdf

What you implement here is not L2 regularization. For L2 regularization, you want to add the squared norm, not the norm itself. (I have also seen the squared norm divided by 2 sometimes.)

Also, are you sure that the test 'bn' not in name is sufficient for determining whether a parameter belongs to a batch normalization layer?

Otherwise, I think your approach looks good.
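
To make that concrete, a sketch of the earlier manual penalty with the squared norm (keeping the ‘bn’ name filter and the optional 1/2 factor):

reg_loss = 0.
for name, param in model.named_parameters():
    if 'bn' not in name:
        reg_loss = reg_loss + 0.5 * param.pow(2).sum()  # squared L2 norm, halved
loss = cls_loss + args.weight_decay * reg_loss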

in https://arxiv.org/pdf/1906.03548.pdf

“He et al. (2019) encourage disabling weight decay on γ and β, but ultimately find diminished performance by doing so.” Interesting.

For ResNets, essentially the only parameters it makes sense to regularize with L2 are the γ parameters of the last batch normalization layer in each block. L2 regularizing any other parameter doesn’t have any effect (as long as you L2 regularize all of them in tandem), except for making the parameters smaller in magnitude and thus more sensitive to change, which can actually mess with the effective learning rate in a big way; see the paper On the Periodic Behavior of Neural Network Training with Batch Normalization and Weight Decay.

Absolutely: “why even the official training code still ignores this fundamental error”

Essentially this means that, in a lot of cases (any model using norm layers in particular), weight decay is completely broken “out of the box” in PyTorch, i.e. you cannot just set a weight decay parameter in the optimiser and expect it to work. I had a hard time even finding this forum thread when I hit the problem. I was very puzzled why AdamW did not work and weight decay did not work. Only in a harder-to-find thread like this is it discussed at all.

It would seem very straightforward to allow all norm layers and biases to be given a weight decay of zero in the optimiser code. It doesn’t even need to be the default setting (as this could break existing code). It can be an additional option, as already mentioned here (4 years ago…).