While training my CNN, I need to apply weight decay only to a subset of the layers, which changes at each forward pass. Is it possible to set weight decay to 0 for some groups of weights dynamically at each forward pass?

You should be able to change the `weight_decay` for the current `param_group` via:

```
import torch
import torch.nn as nn

# Setup
lin = nn.Linear(1, 1, bias=False)
optimizer = torch.optim.SGD(lin.parameters(), lr=1., weight_decay=0.1)
# Store original weight
weight_ref = lin.weight.clone()
# Set gradient to zero (otherwise the step() op will be skipped)
lin.weight.grad = torch.zeros_like(lin.weight)
# Apply weight decay
optimizer.step()
# Store weights after weight decay
weight1 = lin.weight.clone()
# Remove weight decay for this param group
optimizer.param_groups[0]['weight_decay'] = 0.
# Step again
optimizer.step()
# Store weight to compare
weight2 = lin.weight.clone()
# Compare
print(weight_ref)
> tensor([[0.8899]], grad_fn=<CloneBackward>)
print(weight1)
> tensor([[0.8009]], grad_fn=<CloneBackward>)
print(weight2)
> tensor([[0.8009]], grad_fn=<CloneBackward>)
```

Note that I had to manually set the `grad` to zeros, as otherwise the `optimizer.step()` operation would just skip this parameter.
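
To change it dynamically at each iteration, you could e.g. put the layers into separate parameter groups and toggle each group's `weight_decay` inside the training loop before calling `optimizer.step()`. Here is a minimal sketch; the model, the group split, and the condition used to disable weight decay are just placeholders for your own per-pass logic:

```
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

# One param group per layer, so weight decay can be toggled per layer
optimizer = torch.optim.SGD([
    {'params': model[0].parameters(), 'weight_decay': 0.1},
    {'params': model[2].parameters(), 'weight_decay': 0.1},
], lr=1e-2)

criterion = nn.CrossEntropyLoss()

for i in range(10):
    x = torch.randn(8, 10)
    target = torch.randint(0, 2, (8,))

    # Decide which groups should skip weight decay in this iteration
    # (placeholder condition; replace with your own logic)
    for group_idx, group in enumerate(optimizer.param_groups):
        group['weight_decay'] = 0. if (i + group_idx) % 2 == 0 else 0.1

    optimizer.zero_grad()
    loss = criterion(model(x), target)
    loss.backward()
    optimizer.step()
```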