How to optimize weights in the same layer with different weight_decay values

I have a fully connected layer in torch.

self.fc = nn.ModuleList([nn.Sequential(nn.Linear(10, 1),self.activation)])

In the optimizer, I want to assign different weight_decay values to the parameters in this layer. Something like this (not correct code):

optimizer = optim.Adam([{'params':self.fc.parameters()[0:5],'weight_decay':0.01},
                                        {'params':self.fc.parameters()[5:10],'weight_decay':0.01},])

Hi Paul!

If you want to use different values of weight_decay for different
parameters, use the parameter group facility of Optimizer.
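As a minimal sketch of parameter groups, assuming a standalone nn.Linear(10, 1) named fc (in the question's layout it would be self.fc[0][0]) and illustrative decay values, the weight and the bias can each get their own weight_decay:

```python
import torch
import torch.nn as nn
import torch.optim as optim

fc = nn.Linear(10, 1)

# Each dict is one parameter group. Options a group does not set
# (here lr) fall back to the keyword defaults passed to the optimizer.
optimizer = optim.Adam(
    [
        {"params": [fc.weight], "weight_decay": 0.01},  # decay the weights
        {"params": [fc.bias], "weight_decay": 0.0},     # leave the bias alone
    ],
    lr=1e-3,
)

for group in optimizer.param_groups:
    print(group["weight_decay"], group["lr"])
```

Note that groups operate on whole tensors: self.fc.parameters() yields only two tensors (the (1, 10) weight and the (1,) bias), and it is a generator, so it cannot be sliced to pick out individual weights.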

However, if you want to use different weight decays for different
elements of the same parameter, things become more complicated.

The issue is that an entire tensor gets updated “all at once.” That is, you
can’t update some elements of a tensor one way and other elements of
the same tensor some other way (without indexing into the tensor “by
hand”).
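For reference, here is one hedged sketch of what indexing "by hand" could look like: after an ordinary optimizer.step(), shrink only a slice of the weight directly under torch.no_grad(). This is a decoupled (AdamW-style) decay, not an L2 penalty, and the slice boundary 5: and the value extra_decay are hypothetical choices for illustration.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
fc = nn.Linear(10, 1)
lr = 0.1
extra_decay = 0.01  # hypothetical extra decay for weight[:, 5:]

# Optimizer without any built-in weight decay
optimizer = optim.SGD(fc.parameters(), lr=lr)

x = torch.randn(4, 10)
loss = fc(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

before = fc.weight.detach().clone()

# Hand-applied decay: scale only the second half of the weights toward zero.
with torch.no_grad():
    fc.weight[:, 5:].mul_(1 - lr * extra_decay)
```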

One approach is to split the tensor in question up into multiple tensors
(and then put them into separate parameter groups that have different
weight_decay values). While splitting up tensors like this is certainly
doable, it tends to be a hassle.
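To show what the splitting approach involves, here is a sketch of a hypothetical SplitLinear module that stores the (1, 10) weight as two separate Parameters and concatenates them on every forward pass; the split point and decay values are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SplitLinear(nn.Module):
    """Hypothetical Linear(10, 1) whose weight is stored as two Parameters
    so that each can be placed in its own parameter group."""

    def __init__(self, in_features=10, out_features=1, split=5):
        super().__init__()
        ref = nn.Linear(in_features, out_features)  # borrow the default init
        self.weight_a = nn.Parameter(ref.weight[:, :split].detach().clone())
        self.weight_b = nn.Parameter(ref.weight[:, split:].detach().clone())
        self.bias = nn.Parameter(ref.bias.detach().clone())

    def forward(self, x):
        # Reassemble the full weight matrix, then apply an ordinary linear map.
        weight = torch.cat([self.weight_a, self.weight_b], dim=1)
        return F.linear(x, weight, self.bias)


layer = SplitLinear()
optimizer = optim.Adam(
    [
        {"params": [layer.weight_a], "weight_decay": 0.01},
        {"params": [layer.weight_b], "weight_decay": 0.02},
        {"params": [layer.bias], "weight_decay": 0.0},
    ],
    lr=1e-3,
)

x = torch.randn(3, 10)
y = layer(x)
```

Part of the hassle is that such a module no longer looks like nn.Linear: its state_dict keys differ, and code that expects a .weight attribute has to be adapted.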

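Instead, you can recognize that weight decay is, in essence, the same
as applying a quadratic (L2) penalty to the weights. For plain SGD, a
weight_decay of λ is exactly equivalent to adding (λ/2)·Σw² to the loss.
(Note, other optimizers may treat a quadratic penalty and a weight_decay
parameter differently in detail; AdamW, for example, applies decoupled
weight decay, which is not equivalent to an L2 penalty.)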
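A quick numerical check of that correspondence, as a sketch assuming plain torch.optim.SGD (lam, lr, and the toy data_loss are arbitrary): one step with weight_decay=lam lands on the same weights as one step on the loss plus (lam / 2) * sum(w**2).

```python
import torch

torch.manual_seed(0)
lam, lr = 0.1, 0.5
w0 = torch.randn(5)
target = torch.randn(5)


def data_loss(w):
    return ((w - target) ** 2).sum()


# (a) Built-in weight_decay: SGD adds lam * w to the gradient.
w_a = w0.clone().requires_grad_()
opt_a = torch.optim.SGD([w_a], lr=lr, weight_decay=lam)
opt_a.zero_grad()
data_loss(w_a).backward()
opt_a.step()

# (b) Explicit penalty (lam / 2) * sum(w**2), whose gradient is lam * w.
w_b = w0.clone().requires_grad_()
opt_b = torch.optim.SGD([w_b], lr=lr)
opt_b.zero_grad()
(data_loss(w_b) + 0.5 * lam * (w_b ** 2).sum()).backward()
opt_b.step()

print(torch.allclose(w_a, w_b))  # True
```

Under this convention, some_scale_factor for a penalty is weight_decay / 2.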

It’s then easy to give different quadratic penalties – and hence different
weight decays – to parts of the same tensor, say, some_parameter:

Let the tensor penalty_mask have the same shape as some_parameter
and consist, for example, of 1s in the locations of the elements of
some_parameter for which you want the weaker weight decay and 2s
for those elements for which you want the weight decay to be stronger
(in this example, twice as strong).

Then:

# sum() reduces the elementwise penalty to a scalar so that backward() works
quadratic_penalty = some_scale_factor * (penalty_mask * some_parameter**2).sum()
loss_plus_penalty = loss + quadratic_penalty
loss_plus_penalty.backward()
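Applied to the layer in the question, a minimal end-to-end sketch might look like this. It assumes a standalone fc (self.fc[0][0] in the question), a hypothetical base_decay = 0.01 as the weaker decay, and random toy data. The optimizer gets no weight_decay of its own, so the penalty is the only decay on the weight (the bias is left undecayed), and the 0.5 factor follows the SGD convention that weight_decay λ corresponds to (λ/2)·Σw².

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
fc = nn.Linear(10, 1)
base_decay = 0.01  # hypothetical "weaker" weight decay

# 1s for the first five weights, 2s for the last five (twice the decay)
penalty_mask = torch.ones_like(fc.weight)
penalty_mask[:, 5:] = 2.0

# No built-in weight_decay: the masked penalty below does all the decaying.
optimizer = optim.Adam(fc.parameters(), lr=1e-3)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = nn.functional.mse_loss(fc(x), y)
quadratic_penalty = 0.5 * base_decay * (penalty_mask * fc.weight ** 2).sum()
loss_plus_penalty = loss + quadratic_penalty

optimizer.zero_grad()
loss_plus_penalty.backward()
optimizer.step()
```

The penalty's gradient is base_decay * penalty_mask * weight, so elements marked 2 receive twice the decay term of elements marked 1 before Adam's adaptive rescaling.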

Best.

K. Frank
