The AdamW paper gives Adam with decoupled weight decay as Algorithm 2.
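For reference, these are the per-step updates of Algorithm 2 in its AdamW form, i.e. without the λθ_{t−1} term added to the gradient. This is my transcription from the paper, so any slip is mine. α is the step size, β₁ and β₂ are the moment decay rates, ε is the small constant, λ is the weight-decay factor, and η_t is the schedule multiplier:

```latex
\begin{aligned}
g_t &\leftarrow \nabla f_t(\theta_{t-1}) \\
m_t &\leftarrow \beta_1 m_{t-1} + (1 - \beta_1)\, g_t \\
v_t &\leftarrow \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 \\
\hat m_t &\leftarrow m_t / (1 - \beta_1^t) \\
\hat v_t &\leftarrow v_t / (1 - \beta_2^t) \\
\eta_t &\leftarrow \text{SetScheduleMultiplier}(t) \\
\theta_t &\leftarrow \theta_{t-1} - \eta_t \left( \frac{\alpha\, \hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1} \right) \qquad \text{(line 12)}
\end{aligned}
```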
And the corresponding PyTorch implementation is:
```python
# Perform step weight decay
p.data.mul_(1 - group['lr'] * group['weight_decay'])
```
I'm stuck on how line 12 of Algorithm 2 (AdamW) turns into the PyTorch version.
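One way to compare them is to write PyTorch's two in-place steps as a single update. This sketch assumes the rest of PyTorch's step (not shown above) is the usual bias-corrected Adam update p ← p − lr · m̂_t / (√v̂_t + ε), computed from moments that do not depend on the decayed p, so the order of the two steps does not matter. Here lr and wd stand for group['lr'] and group['weight_decay']. η_t, α and λ are the paper's schedule multiplier, step size and decay factor:

```latex
\begin{aligned}
\theta_t &= \theta_{t-1}\,(1 - \mathrm{lr}\cdot\mathrm{wd}) - \mathrm{lr}\,\hat u_t,
  \qquad \hat u_t = \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \\
         &= \theta_{t-1} - \mathrm{lr}\,\bigl(\hat u_t + \mathrm{wd}\,\theta_{t-1}\bigr) \\[4pt]
\text{line 12:}\quad \theta_t &= \theta_{t-1} - \eta_t\,\bigl(\alpha\,\hat u_t + \lambda\,\theta_{t-1}\bigr) \\[4pt]
\text{terms match if}\quad \mathrm{lr} &= \eta_t\,\alpha, \qquad \mathrm{wd} = \lambda / \alpha
\end{aligned}
```

If this reading is right, PyTorch folds η_t and α into a single lr, which a scheduler changes over time. Its weight_decay then corresponds to λ/α rather than to λ itself.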
I googled for a while and found that fast.ai published a post, "AdamW and Super-convergence is now the fastest way to train neural nets", which concluded that AdamW might be implemented something like this:
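```python
# wd is the weight-decay factor; loss and optimizer come from the training loop
loss.backward()
for group in optimizer.param_groups:  # param_groups is a list attribute, not a method
    for param in group['params']:
        # Shrink each weight by lr * wd, independently of the gradient
        param.data = param.data.add(param.data, alpha=-wd * group['lr'])
optimizer.step()
```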
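Below is a small numerical check in plain Python, with no PyTorch dependency. It tests the reading that PyTorch's lr plays the role of η_t·α and its weight_decay the role of λ/α. The function names `paper_adamw_step` and `pytorch_style_adamw_step`, the toy quadratic loss and the cosine schedule are all hypothetical choices made for illustration. The second function only mimics the order "decay p, then Adam step"; it does not call `torch.optim.AdamW`.

```python
import math


def paper_adamw_step(theta, m, v, g, t, alpha, eta, lam,
                     beta1=0.9, beta2=0.999, eps=1e-8):
    """One AdamW step following line 12 of Algorithm 2, on lists of floats."""
    out_theta, out_m, out_v = [], [], []
    for th, mi, vi, gi in zip(theta, m, v, g):
        mi = beta1 * mi + (1 - beta1) * gi
        vi = beta2 * vi + (1 - beta2) * gi * gi
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)
        # theta_t = theta_{t-1} - eta_t * (alpha * m_hat / (sqrt(v_hat) + eps) + lambda * theta_{t-1})
        th = th - eta * (alpha * m_hat / (math.sqrt(v_hat) + eps) + lam * th)
        out_theta.append(th)
        out_m.append(mi)
        out_v.append(vi)
    return out_theta, out_m, out_v


def pytorch_style_adamw_step(p, m, v, g, t, lr, wd,
                             beta1=0.9, beta2=0.999, eps=1e-8):
    """Mimics the PyTorch order: decay p first, then a plain Adam step with lr."""
    out_p, out_m, out_v = [], [], []
    for pi, mi, vi, gi in zip(p, m, v, g):
        pi = pi * (1 - lr * wd)  # like p.data.mul_(1 - lr * weight_decay)
        mi = beta1 * mi + (1 - beta1) * gi
        vi = beta2 * vi + (1 - beta2) * gi * gi
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)
        pi = pi - lr * m_hat / (math.sqrt(v_hat) + eps)
        out_p.append(pi)
        out_m.append(mi)
        out_v.append(vi)
    return out_p, out_m, out_v


def grad(x):
    """Gradient of the toy loss 0.5 * sum(x_i ** 2)."""
    return list(x)


steps = 50
alpha = 1e-3       # paper's step size alpha (example value)
wd = 1e-2          # PyTorch-style weight_decay (example value)
lam = alpha * wd   # paper's lambda under the assumed mapping wd = lambda / alpha

theta = [1.0, -2.0, 0.5]
p = list(theta)
m1, v1 = [0.0] * 3, [0.0] * 3
m2, v2 = [0.0] * 3, [0.0] * 3

for t in range(1, steps + 1):
    eta = 0.5 * (1 + math.cos(math.pi * t / steps))  # hypothetical cosine schedule multiplier
    # Gradients are taken at the pre-decay parameters, as after loss.backward()
    theta, m1, v1 = paper_adamw_step(theta, m1, v1, grad(theta), t, alpha, eta, lam)
    p, m2, v2 = pytorch_style_adamw_step(p, m2, v2, grad(p), t, lr=eta * alpha, wd=wd)

max_diff = max(abs(a - b) for a, b in zip(theta, p))
print("max |theta - p| after", steps, "steps:", max_diff)
```

If the mapping holds, the printed difference should be at floating-point rounding level. Passing wd = lam instead (treating PyTorch's weight_decay as the paper's λ) should make the two trajectories drift apart.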
Am I missing a step needed to get from Algorithm 2 to the PyTorch implementation?
Thanks for any explanation.