In the documentation/code of Adam here I noticed something that seems like an error to me. Even though torch.view_as_real is called on grad first, grad.conj() is called later in the computation of exp_avg_sq. Unless I misunderstood the code, this seems wrong: at that point grad has already been replaced by its real view, so conj() is applied to a real tensor and does nothing. We want to compute the squared magnitude |grad|^2 of the gradient here, right?
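For reference, here is my understanding of the intent (a toy example, not the optimizer code; g is just a made-up complex tensor): for a complex gradient, grad * grad.conj() yields the squared magnitude |grad|^2, which is what the second moment should accumulate.

```python
import torch

# Toy complex "gradient", just for illustration.
g = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j])

print(g * g.conj())    # tensor([ 5.+0.j, 10.+0.j]) -- |g|^2, zero imaginary part
print(g.abs() ** 2)    # tensor([ 5., 10.])         -- same squared magnitudes
```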
Can anyone clarify if this is a bug? Here is the code in question:
```python
if torch.is_complex(param):
    grad = torch.view_as_real(grad)
    exp_avg = torch.view_as_real(exp_avg)
    exp_avg_sq = torch.view_as_real(exp_avg_sq)
    if amsgrad:
        max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
    param = torch.view_as_real(param)

# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
```
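To double-check my reading, here is a minimal repro (again with a toy tensor g standing in for the actual gradient) showing that after view_as_real the conj() call is a no-op:

```python
import torch

# Toy complex "gradient", just for illustration.
g = torch.tensor([1.0 + 2.0j, 3.0 - 1.0j])
gr = torch.view_as_real(g)           # real tensor, shape (2, 2): [[1., 2.], [3., -1.]]

print(gr.is_complex())               # False -- the view is a plain real tensor
print(torch.equal(gr.conj(), gr))    # True  -- conj() does nothing on a real tensor

# So the addcmul_ computes the same thing with or without the conj():
print(torch.equal(gr * gr.conj(), gr * gr))   # True
```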