Differentiating a Complex Valued Function

Hi,

If I wanted to parameterize a function f: R → C via a real valued neural network followed by some operations that make it complex, how are the gradients and higher-order derivatives of f calculated?

As a simple example (below), suppose I have a single linear layer and then multiply the output by a complex number. If I perform backpropagation, it seems that the gradients are only with respect to the real part of the output. In the example below, the gradient of the weight matrix should be [[(2+3i)1, (2+3i)2], [(2+3i)1, (2+3i)2]], but instead it returns [[2x1, 2x2], [2x1, 2x2]].

import torch

linear_layer = torch.nn.Linear(2, 2)
test = linear_layer.forward(torch.Tensor([[1., 2.]]))
(test*(2+3j)).sum().backward()
for i in linear_layer.parameters(): print(i.grad)

If this is what is going on, could I just take the imaginary part and call backward on that?

(test*(2+3j)).imag.sum().backward()   # .imag is an attribute, not a method

This seems to work for this toy example, but wanted to get some confirmation. Also, I am curious if this also holds for higher order derivatives calculated using functorch. E.g. if I take a hessian, will it only take the hessian of the real part of the output?

Hi bchen!

The answer is rather nuanced, and the best (pytorch) explanation is in
the Autograd-for-Complex-Numbers documentation.

There are two things going on here:

First, when you call .backward(), it uses, by default, gradient = None.

The pytorch framework is designed to optimize real-valued loss functions.
Because of this, when you call some_complex_loss.backward() pytorch
computes (in effect) some_complex_loss.real.backward(). (I’m pretty
sure of this, but you might want to scour the above documentation to see
if I’ve missed some nuance.)
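One way to check this without relying on the gradient = None default is to seed the backward pass explicitly with 1 + 0j and compare against the gradient of the real part. This is only a minimal sketch; the names layer, x, grad_seeded and grad_real are made up for illustration, and it uses torch.autograd.grad rather than .backward() so nothing accumulates in .grad.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(2, 2)          # ordinary real-valued parameters
x = torch.tensor([[1.0, 2.0]])

loss = (layer(x) * (2 + 3j)).sum()     # complex scalar "loss"

# Seed backpropagation explicitly with 1 + 0j on the complex output.
(grad_seeded,) = torch.autograd.grad(
    loss, layer.weight, grad_outputs=torch.ones_like(loss), retain_graph=True
)

# Differentiate only the real part of the same output.
(grad_real,) = torch.autograd.grad(loss.real, layer.weight)

print(grad_seeded)
print(grad_real)
```

Under these assumptions both gradients should agree, and both are real because the weights are real.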

Note, in particular, that .backward() computes, in the language of
Wirtinger derivatives, d / dz* (not d / dz).

Second, you are computing the gradient with respect to a real tensor.
(By default, Linear is instantiated with real weights.) This causes
backward() not to store (or maybe not even compute) the full complex
gradient – presumably because tensors and their .grad properties are
required to be of the same type.

If you need the gradient of the imaginary part, yes, call .backward()
on .imag. (Be sure to take into account that .imag is real so that
t_complex = t_complex.real + 1j * t_complex.imag.)
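For the original real-weight example, the two real gradients can be recombined into the complex derivative the question expected. A minimal sketch, assuming a fresh layer and input (the names out, grad_re, grad_im and full_grad are illustrative only):

```python
import torch

layer = torch.nn.Linear(2, 2)          # real-valued parameters
x = torch.tensor([[1.0, 2.0]])
out = (layer(x) * (2 + 3j)).sum()      # complex scalar

# Gradient of the real part; keep the graph for the second call.
(grad_re,) = torch.autograd.grad(out.real, layer.weight, retain_graph=True)
# Gradient of the imaginary part.
(grad_im,) = torch.autograd.grad(out.imag, layer.weight)

# Recombine: d(out)/dW = d(out.real)/dW + 1j * d(out.imag)/dW
full_grad = torch.complex(grad_re, grad_im)
print(full_grad)
```

Because the weights are real, no conjugation is involved here; each entry should come out as (2+3j) times the corresponding input component.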

Here is an illustration of these points:

>>> import torch
>>> print (torch.__version__)
1.13.0
>>>
>>> _ = torch.manual_seed (2022)
>>>
>>> # create a complex Linear that happens to have real coefficients -- experimental
>>> linear_layer = torch.nn.Linear (2, 2).to (dtype = torch.complex64)
<path_to_pytorch_install>\torch\nn\modules\module.py:975: UserWarning: Complex modules are a new feature under active development whose design may change, and some modules might not work as expected when using complex tensors as parameters or buffers. Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml if a complex module does not work as expected.
  warnings.warn(
>>>
>>> # input to Linear must also be complex
>>> test = linear_layer (torch.tensor ([[1. + 1.j, 2. + 2.j]]))
>>>
>>> (test * (2 + 3j)).sum().backward (retain_graph = True)        # gradient = None (by default)
>>> linear_layer.weight.grad   # now you get the full complex grad (of real part of "loss")
tensor([[-1.-5.j, -2.-10.j],
        [-1.-5.j, -2.-10.j]])
>>> linear_layer.bias.grad     # now you get the full complex grad (of real part of "loss")
tensor([2.-3.j, 2.-3.j])
>>>
>>> linear_layer.zero_grad()
>>> (test * (2 + 3j)).sum().real.backward (retain_graph = True)   # same as without .real
>>> linear_layer.weight.grad
tensor([[-1.-5.j, -2.-10.j],
        [-1.-5.j, -2.-10.j]])
>>> linear_layer.bias.grad
tensor([2.-3.j, 2.-3.j])
>>>
>>> linear_layer.zero_grad()
>>> (test * (2 + 3j)).sum().imag.backward (retain_graph = True)   # .imag is (of course) different
>>> linear_layer.weight.grad
tensor([[ 5.-1.j, 10.-2.j],
        [ 5.-1.j, 10.-2.j]])
>>> linear_layer.bias.grad
tensor([3.+2.j, 3.+2.j])

Note the sign of the imaginary part of linear_layer.bias.grad. This is
due to the use of d / dz* as the gradient.
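The sign can be reproduced with a single complex leaf. The sketch below (names c, z, x, y are illustrative) compares the gradient autograd reports for z with the partial derivatives with respect to its real and imaginary parts; for a real loss L the reported value works out to dL/dx + 1j * dL/dy.

```python
import torch

c = 2 + 3j

# Complex leaf: loss L = Re(c * z)
z = torch.tensor(1.0 + 1.0j, requires_grad=True)
(grad_z,) = torch.autograd.grad((c * z).real, z)

# Same loss written in terms of two real leaves, z = x + 1j * y
x = torch.tensor(1.0, requires_grad=True)
y = torch.tensor(1.0, requires_grad=True)
loss = (c * torch.complex(x, y)).real      # L = 2x - 3y
grad_x, grad_y = torch.autograd.grad(loss, [x, y])

print(grad_z)             # gradient reported for the complex leaf
print(grad_x, grad_y)     # dL/dx and dL/dy
```

Here L = 2x - 3y, so the complex gradient is the conjugate of c, which is where the flipped sign of the imaginary part comes from.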

I haven’t used functorch, but I believe that it uses autograd under the
hood, with an overlay of “more efficient” loops. So I expect that these
principles still apply and the autograd documentation still tells you the
(piece-wise) details of what is happening.
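If you want to sidestep the question of what a hessian of a complex output means, one option is to take the hessian of the real and imaginary parts separately. The sketch below uses torch.autograd.functional.hessian rather than functorch (a swapped-in API, since I haven't tested functorch), and f is a made-up toy function standing in for your network.

```python
import torch
from torch.autograd.functional import hessian


def f(x):
    # Toy complex-valued function of a real input (illustrative only).
    return (torch.sin(x) * (2 + 3j)).sum()


x = torch.tensor([0.5, 1.5])

# Each call differentiates a real scalar, so the result is unambiguous.
hess_re = hessian(lambda t: f(t).real, x)
hess_im = hessian(lambda t: f(t).imag, x)

# Recombine into a complex-valued hessian.
hess = torch.complex(hess_re, hess_im)
print(hess)
```

For this toy f the result should be diagonal, with entries -(2+3j) * sin(x).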

Best.

K. Frank
