Hi Youlong!
Looking at the original Adam paper linked to in pytorch’s Adam documentation, I believe this is the way Adam is supposed to work. (I have not looked at pytorch’s Adam implementation.)
Quoting from the abstract of the Adam paper:
The method … is invariant to diagonal rescaling of the gradients
Note that the fact that the direction of the gradient in some sense “washes out” is portrayed as a desirable feature of Adam.
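Concretely, “diagonal rescaling” means multiplying the gradient by a fixed diagonal matrix. As a sketch (using the gradients that show up in the script below), the gradient of fa at [1.0, 1.0] is exactly such a rescaling of the gradient of fs at [1.0, 1.0], so this invariance predicts that Adam will take the same step for both:

import torch

grad_fs = torch.tensor ([2.0, 2.0])               # gradient of fs at [1.0, 1.0]
rescale = torch.diag (torch.tensor ([2.0, 0.5]))  # a fixed diagonal matrix
print (rescale @ grad_fs)                         # tensor([4., 1.]) -- the gradient of fa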
(Some suggestive terminology: The “m” in “Adam” refers to “moment” rather than “momentum.” It is true that Adam “accumulates moments,” but in its very first step it moves in a direction that scales like gradient / sqrt (gradient**2), which is to say, it does not (necessarily) move in the direction of the gradient.)
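To make the first-step claim concrete, here is a sketch of Adam’s first update computed by hand, following algorithm 1 of the Adam paper with pytorch’s default hyperparameters (lr = 0.001, betas = (0.9, 0.999), eps = 1e-8):

import torch

lr, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8

grad = torch.tensor ([4.0, 1.0])           # gradient of fa at [1.0, 1.0]
m = (1 - beta1) * grad                     # first moment (initialized to zero)
v = (1 - beta2) * grad**2                  # second moment (initialized to zero)
m_hat = m / (1 - beta1)                    # bias correction at step t = 1
v_hat = v / (1 - beta2)
print (lr * m_hat / (v_hat.sqrt() + eps))  # tensor([0.0010, 0.0010])

Both components of the step have magnitude lr, regardless of how large the gradient components are; only their signs matter.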
We can verify that this behavior is displayed by pytorch’s Adam:
>>> import torch
>>> torch.__version__
'1.9.0'
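>>> # fs is symmetric in t[0] and t[1]; fa weights the two directions differently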
>>> def fs (t):
... return (t * t * torch.tensor ([1.0, 1.0])).sum()
...
>>> def fa (t):
... return (t * t * torch.tensor ([2.0, 0.5])).sum()
...
>>> tss = torch.tensor ([1.0, 1.0], requires_grad = True)
>>> tsa = torch.tensor ([1.0, 1.0], requires_grad = True)
>>> tas = torch.tensor ([1.0, 1.0], requires_grad = True)
>>> taa = torch.tensor ([1.0, 1.0], requires_grad = True)
>>> sgds = torch.optim.SGD ([tss], lr = 0.1)
>>> sgda = torch.optim.SGD ([tsa], lr = 0.1)
>>> adas = torch.optim.Adam ([tas])
>>> adaa = torch.optim.Adam ([taa])
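>>> # compute gradients of fs and fa at the common starting point [1.0, 1.0]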
>>> fs (tss).backward()
>>> tss.grad
tensor([2., 2.])
>>> fa (tsa).backward()
>>> tsa.grad
tensor([4., 1.])
>>> fs (tas).backward()
>>> tas.grad
tensor([2., 2.])
>>> fa (taa).backward()
>>> taa.grad
tensor([4., 1.])
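>>> # one optimizer step: SGD moves along the gradient; Adam moves about lr * sign of the gradient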
>>> sgds.step()
>>> tss
tensor([0.8000, 0.8000], requires_grad=True)
>>> sgda.step()
>>> tsa
tensor([0.6000, 0.9000], requires_grad=True)
>>> adas.step()
>>> tas
tensor([0.9990, 0.9990], requires_grad=True)
>>> adaa.step()
>>> taa
tensor([0.9990, 0.9990], requires_grad=True)
Here, fa (t) is a function that has a larger gradient in the t[0] direction than in the t[1] direction (while fs (t) is symmetrical). But you can see that Adam takes a step directly toward the origin (the minimum), and this is not in the direction of the gradient of fa (t). (By way of comparison, the steps taken by SGD are in the direction of the gradient.)
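If it helps, you can check the step directions explicitly. This is just a sketch that normalizes the steps taken above (direction() is a hypothetical helper, not part of pytorch):

import torch

def direction (before, after):
    step = after - before        # step the optimizer took
    return step / step.norm()    # unit vector along that step

start = torch.tensor ([1.0, 1.0])
print (direction (start, torch.tensor ([0.6000, 0.9000])))   # SGD on fa: tensor([-0.9701, -0.2425]), parallel to the gradient [4., 1.]
print (direction (start, torch.tensor ([0.9990, 0.9990])))   # Adam on fa: tensor([-0.7071, -0.7071]), straight toward the origin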
Best.
K. Frank