Weight decay only for weights of nn.Linear and nn.Conv*

In many of the papers and blogs I have read, for example the recent NFNet paper, the authors emphasize the importance of applying weight decay only to the weights of convolution and linear layers. Bias values for all layers, as well as the weight and bias values of normalization layers, e.g., LayerNorm, should be excluded from weight decay.
However, setting different weight decay values for different layer types in the model is not straightforward with PyTorch optimizers. The official documentation on per-parameter options shows that I can put the .weight attributes of the desired classes in one parameter group and the rest in a separate group where weight decay is not applied.
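As a minimal sketch of those per-parameter options (the toy model, `lr=1e-3` and `weight_decay=0.05` are illustrative values, not from the thread), each dict in the list becomes one parameter group, and any key a group omits falls back to the optimizer-wide default:

```python
from torch import nn, optim

# Toy model for illustration only.
model = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))

# Hand-picked groups: only the two Linear weights receive weight decay.
decay = [model[0].weight, model[2].weight]
no_decay = [model[0].bias, model[1].weight, model[1].bias, model[2].bias]

optimizer = optim.AdamW(
    [
        {"params": decay},  # inherits the default weight_decay below
        {"params": no_decay, "weight_decay": 0.0},  # overrides the default
    ],
    lr=1e-3,
    weight_decay=0.05,
)

for group in optimizer.param_groups:
    print(len(group["params"]), group["weight_decay"])
# 2 0.05
# 4 0.0
```

The hard part, as the thread goes on to discuss, is building those two lists automatically instead of by hand.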
However, the difficulty is in separating the classes of the sub-modules, which are often buried inside nn.Sequential layers without names.
I am currently trying to use the .modules() method to separate the sub-modules but I have no idea whether this is correct.

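import torch.nn as nn
import torch.optim as optim

no_decay = list()
decay = list()
for m in model.modules():
  if isinstance(m, (nn.Linear, nn.Conv2d)):  # add other layer types as needed
    decay.append(m.weight)
    if m.bias is not None:  # layers built with bias=False have bias = None
      no_decay.append(m.bias)
  else:
    # e.g. normalization layers: both weight and bias go without decay
    if getattr(m, 'weight', None) is not None:
      no_decay.append(m.weight)
    if getattr(m, 'bias', None) is not None:
      no_decay.append(m.bias)

optimizer = optim.AdamW([{'params': no_decay, 'weight_decay': 0}, {'params': decay}], **kwargs)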

The code above has the obvious problem that trainable parameters that are not called weight or bias are excluded completely from training.
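A quick way to see this problem is to compare the collected lists against `model.named_parameters()`. In this sketch, `ScaledLinear` is a hypothetical module with an extra parameter called `scale`, and the loop mirrors the type check from the question; the set of `id()` values is used only for the membership test and is never passed to an optimizer:

```python
import torch
from torch import nn


class ScaledLinear(nn.Module):
    """Hypothetical module with a parameter named neither 'weight' nor 'bias'."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.scale = nn.Parameter(torch.ones(4))

    def forward(self, x):
        return self.linear(x) * self.scale


model = ScaledLinear()
decay, no_decay = [], []
for m in model.modules():
    if isinstance(m, nn.Linear):
        decay.append(m.weight)
        no_decay.append(m.bias)

# Compare by identity: any parameter not collected would never be optimized.
collected = {id(p) for p in decay + no_decay}
missing = [name for name, p in model.named_parameters() if id(p) not in collected]
print(missing)  # ['scale']
```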
Is there any better method to do this?


Your approach seems reasonable. Iterating over all named parameters of the model might be a bit easier:

import torchvision.models as models

model = models.resnet152()

decay = dict()
no_decay = dict()
for name, param in model.named_parameters():
    print('checking {}'.format(name))
    if 'weight' in name:
        decay[name] = param
    else:
        no_decay[name] = param

print(decay.keys())
print(no_decay.keys())

@ptrblck Thank you for the solution. However, I believe that this requires naming all layers that go inside an nn.Sequential layer. By default, layers within an nn.Sequential do not retain their names.

As an alternative, I have considered adding the following code.

import torch.nn as nn
import torch.optim as optim

no_decay = list()
decay = list()
for m in model.modules():
  if isinstance(m, (nn.Linear, nn.Conv2d)):  # add other layer types as needed
    decay.append(m.weight)
    if m.bias is not None:
      no_decay.append(m.bias)
  else:
    if getattr(m, 'weight', None) is not None:
      no_decay.append(m.weight)
    if getattr(m, 'bias', None) is not None:
      no_decay.append(m.bias)
# Also collect every parameter whose own name is neither weight nor bias.
for name, param in model.named_parameters():
  if name.rsplit('.', 1)[-1] not in ('weight', 'bias'):
    no_decay.append(param)
optimizer = optim.AdamW([{'params': no_decay, 'weight_decay': 0}, {'params': decay}], **kwargs)

I extract the weight and bias values through .modules() and catch the remaining parameters through .named_parameters() because .modules() returns every module, including container modules that hold other modules, whereas .named_parameters() returns only the parameters themselves, i.e., the leaves of the recursion.

nn.Sequential modules will add the index to the parameter names, such as 0.weight, 1.weight etc.
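For illustration (the model below is a toy example), the names inside an `nn.Sequential` are just the child indices, and `.modules()` also yields the container itself and parameter-free layers such as `nn.ReLU`, whereas `.named_parameters()` yields only the parameters:

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

module_types = [type(m).__name__ for m in model.modules()]
param_names = [name for name, _ in model.named_parameters()]

print(module_types)  # ['Sequential', 'Linear', 'ReLU', 'Linear']
print(param_names)   # ['0.weight', '0.bias', '2.weight', '2.bias']
```

Note that the ReLU at index 1 owns no parameters, so no name starting with `1.` appears.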

I might misunderstand the use case, but I thought that’s exactly what you need: the actual parameters at the most inner level.

Hello.
The reason I used .modules() was to exclude normalization layers such as nn.LayerNorm and nn.BatchNorm2d.
These layers also have parameters named weight and bias, but these should be excluded from weight decay.
However, inside an nn.Sequential, the layer type cannot be inferred from the parameter names (e.g., 0.weight, 1.weight) unless the layers have been given explicit names. This is why I use .modules() instead of .named_parameters().
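One way to get both the layer type and a readable name is to walk `named_modules()` and ask each module only for the parameters registered directly on it, via `named_parameters(recurse=False)`. This is a sketch under the assumption that only the weights of linear and convolution layers should be decayed; `DECAY_TYPES` and `split_by_module_type` are hypothetical names, and the function returns names for readability (append `param` instead of `full_name` to build the optimizer lists):

```python
from torch import nn

# Hypothetical selection of layer types whose weights get weight decay.
DECAY_TYPES = (
    nn.Linear,
    nn.Conv1d, nn.Conv2d, nn.Conv3d,
    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
)


def split_by_module_type(model: nn.Module):
    """Return (decay_names, no_decay_names) using module types, not name patterns."""
    decay, no_decay = [], []
    for module_name, module in model.named_modules():
        # recurse=False: only parameters registered directly on this module.
        for param_name, param in module.named_parameters(recurse=False):
            full_name = f"{module_name}.{param_name}" if module_name else param_name
            if isinstance(module, DECAY_TYPES) and param_name == "weight":
                decay.append(full_name)
            else:
                no_decay.append(full_name)
    return decay, no_decay


model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
decay, no_decay = split_by_module_type(model)
print(decay)     # ['0.weight', '2.weight']
print(no_decay)  # ['0.bias', '1.weight', '1.bias', '2.bias']
```

Because every parameter is visited together with the module that owns it, parameters with unusual names land in `no_decay` instead of being dropped. Weights tied between two modules would be visited once per owning module, so those may need deduplication.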

On second thought, it is much better to use Python sets.

from torch import nn

all_params = set(model.parameters())
wd_params = set()
for m in model.modules():
  if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d,
                    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
    wd_params.add(m.weight)
no_wd = all_params - wd_params

Do you happen to have an alternative to using sets? The PyTorch documentation page on optimizers says explicitly that sets do not satisfy the requirements for passing parameters to an optimizer.

Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don’t satisfy those properties are sets and iterators over values of dictionaries.
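The ordering matters because `Optimizer.state_dict()` stores each group's parameters as integer positions, numbered across the groups in order, rather than by name. A sketch with a toy model (the groups and hyperparameters are illustrative):

```python
from torch import nn, optim

model = nn.Linear(3, 2)
optimizer = optim.SGD(
    [
        {"params": [model.weight]},
        {"params": [model.bias], "weight_decay": 0.0},
    ],
    lr=0.1,
    weight_decay=1e-4,
)

# Parameters are saved as positions, not names or tensors.
saved = optimizer.state_dict()
print([group["params"] for group in saved["param_groups"]])  # [[0], [1]]
```

If a later run built the groups in a different order, `load_state_dict` would attach saved per-parameter state, such as momentum buffers, to different tensors.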

Thanks in advance!

@john90 Thank you for pointing this out! It may have been causing unknown issues in my code for a very long time.

I have a simple solution that does not rely on any undefined ordering.

import torch
from torch import nn


@torch.no_grad()
def get_wd_params(model: nn.Module):
    # Parameters must have a defined order. 
    # No sets or dictionary iterations.
    # See https://pytorch.org/docs/stable/optim.html#base-class
    # Parameters for weight decay.
    all_params = tuple(model.parameters())
    wd_params = list()
    for m in model.modules():
        if isinstance(
                m,
                (
                        nn.Linear,
                        nn.Conv1d,
                        nn.Conv2d,
                        nn.Conv3d,
                        nn.ConvTranspose1d,
                        nn.ConvTranspose2d,
                        nn.ConvTranspose3d,
                ),
        ):
            wd_params.append(m.weight)
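    # Only weights of specific layers should undergo weight decay.
    # Compare by identity: `p not in wd_params` would fall back to tensor `==`,
    # which raises for tensors with more than one element or mismatched shapes.
    wd_ids = {id(p) for p in wd_params}  # used only for lookups, never passed to the optimizer
    no_wd_params = [p for p in all_params if id(p) not in wd_ids]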
    assert len(wd_params) + len(no_wd_params) == len(all_params), "Sanity check failed."
    return wd_params, no_wd_params
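Membership tests on lists of tensors deserve care: `p in some_list` short-circuits on identity, but for entries that are different objects Python falls back to `==`, which PyTorch evaluates elementwise. A sketch of the pitfall and of an identity-based alternative, using a toy `nn.Linear` (variable names are illustrative):

```python
from torch import nn

model = nn.Linear(3, 2)
wd_params = [model.weight]

# Same object: the identity check short-circuits, so this works.
print(model.weight in wd_params)  # True

# Different object: Python evaluates weight == bias, and shapes (2, 3)
# and (2,) cannot be broadcast, so PyTorch raises a RuntimeError.
raised = False
try:
    model.bias in wd_params
except RuntimeError:
    raised = True
print(raised)  # True

# Identity-based filtering keeps the order of model.parameters().
wd_ids = {id(p) for p in wd_params}  # set used only for lookups
no_wd_params = [p for p in model.parameters() if id(p) not in wd_ids]
print(len(no_wd_params))  # 1
```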

Hi, I had the same problem and found this thread through Google. Below is my solution, which combines both approaches mentioned above.

import torch
from torch import nn


@torch.no_grad()
def get_wd_params(model: nn.Module):
    decay = list()
    no_decay = list()
    for name, param in model.named_parameters():
        print('checking {}'.format(name))
        if not param.requires_grad:  # skip frozen parameters
            continue
        # Name-based heuristic: relies on norm layers having 'norm' or 'bn' in their names.
        if 'weight' in name and 'norm' not in name and 'bn' not in name:
            decay.append(param)
        else:
            no_decay.append(param)
    return decay, no_decay
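This name-based heuristic depends on how layers are named. As a sketch with a toy model, inside an unnamed `nn.Sequential` a `BatchNorm1d` weight is called `1.weight`, which contains neither `norm` nor `bn`, so it lands in the decay group:

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))

decay, no_decay = [], []
for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    # Same substring test as the get_wd_params function above.
    if "weight" in name and "norm" not in name and "bn" not in name:
        decay.append(name)
    else:
        no_decay.append(name)

print(decay)     # ['0.weight', '1.weight', '2.weight'], BatchNorm weight included
print(no_decay)  # ['0.bias', '1.bias', '2.bias']
```

With explicitly named submodules (e.g. torchvision's ResNet, whose norm layers are called `bn1`, `bn2`, ...) the heuristic behaves as intended.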

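Fortunately, I found that I had been converting all sets to tuples before giving them to the optimizer, so the optimizer never rejected them. Note, however, that tuple(some_set) keeps the set's iteration order, which for tensors depends on object ids and may differ between runs, so the conversion alone does not make the ordering deterministic.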