Change weight decay during training

While training my CNN, I need to apply weight decay only to a subset of the layers, which changes at each forward pass. Is it possible to set weight decay to 0 for some groups of weights dynamically at each forward pass?


@ptrblck Any insights on this? I am looking for something like this as well 🙂

You should be able to change the weight_decay for the current param_group via:

import torch
import torch.nn as nn

# Setup
lin = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.SGD(
    lin.parameters(), lr=1., weight_decay=0.1)

# Store original weight
weight_ref = lin.weight.clone()

# Set gradient to zero (otherwise the step() op will be skipped)
lin.weight.grad = torch.zeros_like(lin.weight)

# Apply weight decay
optimizer.step()

# Store weights after weight decay
weight1 = lin.weight.clone()

# Remove weight decay for this param group
optimizer.param_groups[0]['weight_decay'] = 0.

# Step again
optimizer.step()

# Store weight to compare
weight2 = lin.weight.clone()

# Compare
print(weight_ref)
> tensor([[0.8899]], grad_fn=<CloneBackward>)
print(weight1)
> tensor([[0.8009]], grad_fn=<CloneBackward>)
print(weight2)
> tensor([[0.8009]], grad_fn=<CloneBackward>)

Note that I had to manually set the grad to zeros, as otherwise the optimizer.step() operation would just skip this parameter.
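For the original use case (disabling weight decay for a subset of layers that changes every iteration), the same idea extends to one param group per layer: toggle each group's weight_decay right before optimizer.step(). Here is a minimal sketch, assuming a small two-layer model and a random mask as a stand-in for your own layer-selection logic:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy two-layer model, just for illustration
model = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 1))

# One param group per layer so weight_decay can be toggled independently
optimizer = torch.optim.SGD(
    [{'params': layer.parameters(), 'weight_decay': 0.1} for layer in model],
    lr=0.01)

data = torch.randn(8, 10)
target = torch.randn(8, 1)

for step in range(10):
    # Decide which layers get weight decay this iteration
    # (random mask here; replace with your own selection logic)
    decay_mask = torch.rand(len(optimizer.param_groups)) > 0.5
    for group, use_decay in zip(optimizer.param_groups, decay_mask):
        group['weight_decay'] = 0.1 if use_decay else 0.

    optimizer.zero_grad()
    loss = F.mse_loss(model(data), target)
    loss.backward()
    optimizer.step()

Changing weight_decay on a param group takes effect on the next step() call, so updating the groups once per iteration (before stepping) should give you the per-forward-pass behavior you described.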