Understanding PyTorch optimization: weight decay

I want to understand how weight decay (L2 penalty) works in
`torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)`.
How does the L2 penalty actually work during optimization?

It is more obvious in the older (1.6) sources, where weight_decay only affects one line:

`d_p = d_p.add(p, alpha=weight_decay)`

i.e. `gradient = gradient + param_value * weight_decay`,
so as a parameter value moves away from zero, the step taken for it grows larger.
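
A minimal sketch of what that line does (my own illustration, not the actual `torch.optim` source): the L2 penalty is folded into the gradient right before a plain SGD-style parameter update.

```python
import torch

@torch.no_grad()
def sgd_step_with_weight_decay(p, lr=0.1, weight_decay=0.01):
    d_p = p.grad
    # The line in question: fold the L2 penalty into the gradient.
    d_p = d_p.add(p, alpha=weight_decay)   # d_p = d_p + weight_decay * p
    p.add_(d_p, alpha=-lr)                 # p = p - lr * d_p

# Usage: one parameter tensor, one manual step.
p = torch.tensor([2.0, -3.0], requires_grad=True)
loss = (p ** 2).sum()
loss.backward()
sgd_step_with_weight_decay(p)
```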

So it is constraining the current gradient using the current parameter values?

Yes, of course the "current" values are used. Mathematically, the effect is the same as the usual L2 penalty; it is just applied through the gradient.
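
A small numerical check (my sketch, not from this thread) that the two views agree: adding `weight_decay * p` to the gradient gives the same result as differentiating a loss that contains an explicit `(weight_decay / 2) * ||p||^2` term.

```python
import torch

wd = 0.1
p = torch.tensor([2.0, -3.0], requires_grad=True)

# Route 1: plain loss, then fold the decay into the gradient afterwards.
loss = (p ** 3).sum()                     # any differentiable loss works here
g_plain, = torch.autograd.grad(loss, p)
g_via_gradient = g_plain + wd * p.detach()

# Route 2: put the L2 penalty into the loss itself.
loss_l2 = (p ** 3).sum() + 0.5 * wd * (p ** 2).sum()
g_via_loss, = torch.autograd.grad(loss_l2, p)

print(torch.allclose(g_via_gradient, g_via_loss))  # True
```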
