NaN gradient after complex square root of 0 value

Can anyone think of a nice work-around for this issue?

>>> import torch
>>> torch.__version__
'2.0.1+cu118'
>>> z = torch.tensor(0. + 0.j, requires_grad=True)
>>> z.sqrt().abs().backward()
>>> z.grad
tensor(nan+nanj)

(I’m willing to post an issue, but I’d like to characterize the problem a bit more first. Does it look like a bug? I haven’t properly dived into the math behind Wirtinger derivatives yet.)

I suspect I could work around this with something like torch.where(z != 0, z, epsilon) or by zeroing out all the nans, but both seem rather awkward with complex numbers / gradients.
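
For concreteness, here is a rough sketch of the torch.where idea (the value of epsilon and the direction of the nudge, along the positive real axis, are arbitrary choices on my part):

import torch

epsilon = 1e-12
z = torch.tensor(0. + 0.j, requires_grad=True)
# Nudge exact zeros away from the branch point before taking the square root.
z_safe = torch.where(z != 0, z, torch.full_like(z, epsilon))
z_safe.sqrt().abs().backward()
print(z.grad)  # 0.+0.j at the zero entry (the constant branch is selected), no more nan

This does avoid the nan, but it silently produces a zero gradient at the entries that were exactly zero, which may or may not be what I want.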

Hi Rehno!

This is not a bug – it is the mathematically correct result (and is to be
expected).

The square-root function has a branch point at z = 0. + 0.j, and its
derivative, 1 / (2 * sqrt(z)), diverges there, so the derivative is not well
defined and nan is the appropriate result.

It would indeed be awkward. The core problem is that you want to compute
a derivative at the singular point z = 0. + 0.j.

Let’s say you have some small, positive, real epsilon. You might imagine
clamping z so that it stays at least epsilon away from zero. But should you set
z = epsilon or z = -epsilon or z = 1.j * epsilon, and so on? The
ambiguity remains, even if you try to sweep it under the rug.
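
To make the ambiguity concrete, here is a small sketch (the value of eps is
arbitrary) showing that the gradient you get depends on the direction in which
you nudge z away from zero:

import torch

eps = 1e-6
for nudge in (eps + 0.j, -eps + 0.j, 1.j * eps):  # three different "directions" of epsilon
    z = torch.tensor(nudge, dtype=torch.complex128, requires_grad=True)
    z.sqrt().abs().backward()
    print(nudge, '->', z.grad)  # the gradient points in a different direction for each choice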

Figure out in the context of your larger computation why you are trying to
compute the gradient of z.sqrt() when z is zero, and avoid doing that
in the first place.
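
If the zero entries are ones you don't actually need gradients for, one way
to avoid the problem entirely, rather than papering over it, is to mask them
out so that sqrt's backward never sees z = 0. A rough sketch (the example
tensor is mine):

import torch

z = torch.tensor([0. + 0.j, 4. + 0.j], requires_grad=True)
mask = z != 0
# Only the nonzero entries pass through sqrt, so its backward never sees z = 0.
loss = z[mask].sqrt().abs().sum()
loss.backward()
print(z.grad)  # finite gradient for the nonzero entry, zero for the masked-out entry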

Best.

K. Frank