How to implement a custom loss function that includes the Frobenius norm?

Hello Yuchen!

Complex tensors are still a work in progress in pytorch, with more
functionality (and more functionality that is actually correct) being
added in successive versions.

Note that as of version 1.6.0, I don’t believe that the pytorch optimizers
accept complex Parameters, so to use pytorch’s complex machinery,
you will have to either use real Parameters that you combine into
complex tensors or write your own complex-aware optimizer.
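Here is a minimal sketch of the first option. It assumes a recent PyTorch release in which `torch.complex()` exists and supports autograd. The optimizer only ever sees two real Parameters, and the complex tensor is rebuilt from them inside the loss. The names `z_re` and `z_im`, the learning rate, and the step count are all illustrative.

```python
import torch

torch.manual_seed(2020)

# Two real Parameters hold the real and imaginary parts, so the optimizer only sees real tensors
z_re = torch.nn.Parameter(torch.randn(2, 3))
z_im = torch.nn.Parameter(torch.randn(2, 3))
opt = torch.optim.SGD([z_re, z_im], lr=0.01)

with torch.no_grad():
    initial = torch.sqrt((z_re ** 2 + z_im ** 2).sum()).item()

for step in range(100):
    opt.zero_grad()
    z = torch.complex(z_re, z_im)                  # combine the real Parameters into a complex tensor
    loss = torch.sqrt((z * z.conj()).real.sum())   # Frobenius norm, kept as a real scalar
    loss.backward()                                # gradients flow back to z_re and z_im
    opt.step()

with torch.no_grad():
    final = torch.sqrt((z_re ** 2 + z_im ** 2).sum()).item()
print(initial, final)
```

Because both the loss and the Parameters are real, the optimizer performs ordinary gradient descent on the real and imaginary parts. No conjugation of a complex gradient is involved.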

All in all, depending on what you are doing, it might be safest to
represent the real and imaginary parts of your complex tensors as
separate real tensors and carry out the complex arithmetic “by hand.”
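For example, a complex product and the Frobenius norm can be written with ordinary real tensors. The product uses (a + ib)(c + id) = (ac − bd) + i(ad + bc), and the norm uses |a + ib|² = a² + b². The helpers `cmul` and `frobenius` are hypothetical names for this sketch, not pytorch functions.

```python
import torch

def cmul(a_re, a_im, b_re, b_im):
    # element-wise (a_re + i a_im) * (b_re + i b_im), using only real tensors
    return a_re * b_re - a_im * b_im, a_re * b_im + a_im * b_re

def frobenius(z_re, z_im):
    # square root of the sum of squared magnitudes |z|^2 = re^2 + im^2
    return torch.sqrt((z_re ** 2 + z_im ** 2).sum())

torch.manual_seed(2020)
a_re, a_im = torch.randn(2, 3), torch.randn(2, 3)
b_re, b_im = torch.randn(2, 3), torch.randn(2, 3)

p_re, p_im = cmul(a_re, a_im, b_re, b_im)
norm = frobenius(p_re, p_im)
print(norm)
```

Every operation here is a plain real tensor operation, so autograd handles it the same way in any pytorch version.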

The Frobenius norm of a (complex) matrix is simply the square root
of the sum of the squares of the (absolute values of the) individual
matrix elements. Pytorch’s tensor operations can do this* reasonably
straightforwardly.
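In symbols, take an m × n matrix A whose entries are a_jk = x_jk + i y_jk, with x and y real:

```latex
\|A\|_F = \sqrt{\sum_{j=1}^{m}\sum_{k=1}^{n} |a_{jk}|^2}
        = \sqrt{\sum_{j,k} a_{jk}\,\overline{a_{jk}}}
        = \sqrt{\sum_{j,k} \left(x_{jk}^2 + y_{jk}^2\right)}
```

The middle form corresponds to `(z * z.conj()).sum()`. The last form corresponds to summing the squared real and imaginary parts separately. These are the two forms used in the script that follows.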

*) With the proviso that complex tensors are a work in progress.

Note that as of version 1.6.0, torch.norm() is incorrect for complex
tensors: it uses the squares, rather than the squared absolute values,
of the matrix elements, so it can return a complex value instead of a
positive real number.
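Here is a minimal illustration of the difference. It uses plain Python complex numbers rather than tensors, and the three entries are arbitrary. Summing z² lets the imaginary parts interfere and gives a complex result. Summing |z|² = z·z̄ gives a real, non-negative result.

```python
import cmath
import math

# A few complex "matrix elements" (arbitrary values for illustration)
entries = [1 + 2j, 3 - 1j, -0.5 + 0.5j]

# Incorrect: square root of the sum of plain squares z*z, complex in general
wrong = cmath.sqrt(sum(z * z for z in entries))

# Correct: square root of the sum of squared magnitudes z * conj(z), always real
right = math.sqrt(sum((z * z.conjugate()).real for z in entries))

print(wrong, right)
```

Here the sum of z² is 5 − 2.5i, while the sum of |z|² is 15.5. Only the second gives the Frobenius norm, √15.5 ≈ 3.937.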

Here is a script that illustrates calculating and backpropagating the
Frobenius norm:

```python
import torch
torch.__version__

_ = torch.random.manual_seed (2020)

x = torch.randn ([2, 3])
print ('x = ...\n', x)
print ('torch.norm (x) =', torch.norm (x))   # okay
z = torch.randn ([2, 3], dtype = torch.cfloat)
print ('z = ...\n', z)
print ('torch.norm (z) =', torch.norm (z))   # oops, should be positive real
z.requires_grad = True
znorm = torch.sqrt ((z * z.conj()).sum())
print ('znorm =', znorm)
znorm.backward()
print ('z.grad =', z.grad)
z.grad.zero_()
znormb = torch.sqrt ((torch.real (z)**2).sum() + (torch.imag (z)**2).sum())
print ('znormb =', znormb)
znormb.backward()
print ('z.grad =', z.grad)
```

And here is its (version 1.6.0) output:

```text
x = ...
 tensor([[ 1.2372, -0.9604,  1.5415],
        [-0.4079,  0.8806,  0.0529]])
torch.norm (x) = tensor(2.4029)
z = ...
 tensor([[ 0.0531+0.3378j, -0.4779-1.5195j, -0.8105-0.1923j],
        [ 0.7118-0.0294j, -0.9088-0.3499j, -0.9167-0.8840j]])
torch.norm (z) = tensor(1.3644+1.4714j)
znorm = tensor(2.5350+0.j, grad_fn=<SqrtBackward>)
z.grad = tensor([[ 0.0210-0.1332j, -0.1885+0.5994j, -0.3197+0.0759j],
        [ 0.2808+0.0116j, -0.3585+0.1380j, -0.3616+0.3487j]])
znormb = tensor(2.5350, grad_fn=<SqrtBackward>)
z.grad = tensor([[ 0.0210-0.1332j, -0.1885+0.5994j, -0.3197+0.0759j],
        [ 0.2808+0.0116j, -0.3585+0.1380j, -0.3616+0.3487j]])
```

Be careful, however, with what you do with a complex gradient. In
version 1.6.0, you will have to take the complex conjugate of the
gradient to use it with gradient-descent optimization.
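Here is a short derivation of why this is needed, based on the version 1.6.0 output above. Write z_jk = x_jk + i y_jk, L = ‖z‖_F, and η = lr. The real partial derivatives are ∂L/∂x_jk = x_jk / L and ∂L/∂y_jk = y_jk / L. The gradient printed above is conj(z) / L; for example, (0.0531 − 0.3378i) / 2.5350 ≈ 0.0210 − 0.1332i. The two possible update steps are therefore:

```latex
g = \frac{\bar{z}}{L} = \frac{\partial L}{\partial x} - i\,\frac{\partial L}{\partial y},
\qquad
z - \eta\,\bar{g} = \Bigl(1 - \frac{\eta}{L}\Bigr) z,
\qquad
z - \eta\,g = \Bigl(1 - \frac{\eta}{L}\Bigr) x + i\,\Bigl(1 + \frac{\eta}{L}\Bigr) y
```

Stepping along ḡ is ordinary steepest descent on the real and imaginary parts, and it shrinks every entry toward zero. Stepping along g still shrinks the real parts, but it grows the imaginary parts.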

This script illustrates this behavior by minimizing the Frobenius norm
with gradient descent:

```python
import torch
torch.__version__

_ = torch.random.manual_seed (2020)

za = torch.randn ([2, 3], dtype = torch.cfloat)
zb = za.clone()

lr = 0.001

# gradient descent A
za.requires_grad = True
print ('za =', za)
for  i in range (10001):
    if  za.grad is not None:  _ = za.grad.zero_()
    znorm = torch.sqrt ((za * za.conj()).sum())
    znorm.backward()
    with torch.no_grad():
        _ = za.copy_ (za - lr * za.grad)   # doesn't converge
    if  i % 1000 == 0:  print ('znorm =', znorm)

print ('za =', za)

# gradient descent B
zb.requires_grad = True
print ('zb =', zb)
for  i in range (10001):
    if  zb.grad is not None:  _ = zb.grad.zero_()
    znorm = torch.sqrt ((zb * zb.conj()).sum())
    znorm.backward()
    with torch.no_grad():
        _ = zb.copy_ (zb - lr * zb.grad.conj())   # use conjugate of gradient to get convergence
    if  i % 1000 == 0:  print ('znorm =', znorm)

print ('zb =', zb)
```

And here is its output:

```text
za = tensor([[ 0.8749-0.6791j,  1.0900-0.2884j,  0.6227+0.0374j],
        [ 0.0531+0.3378j, -0.4779-1.5195j, -0.8105-0.1923j]],
       requires_grad=True)
znorm = tensor(2.4970+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(2.8249+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(3.6008+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(4.5226+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(5.4901+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(6.4745+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(7.4661+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(8.4612+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(9.4581+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(10.4563+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(11.4551+0.j, grad_fn=<SqrtBackward>)
za = tensor([[ 0.1324-4.4859j,  0.1649-1.9052j,  0.0942+0.2472j],
        [ 0.0080+2.2312j, -0.0723-10.0381j, -0.1226-1.2704j]],
       requires_grad=True)
zb = tensor([[ 0.8749-0.6791j,  1.0900-0.2884j,  0.6227+0.0374j],
        [ 0.0531+0.3378j, -0.4779-1.5195j, -0.8105-0.1923j]],
       requires_grad=True)
znorm = tensor(2.4970+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(1.4970+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.4970+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
znorm = tensor(0.0010+0.j, grad_fn=<SqrtBackward>)
zb = tensor([[ 7.4477e-06-5.7811e-06j,  9.2789e-06-2.4552e-06j,
          5.3007e-06+3.1863e-07j],
        [ 4.5219e-07+2.8753e-06j, -4.0686e-06-1.2936e-05j,
         -6.9001e-06-1.6372e-06j]], requires_grad=True)
```

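You can see that without taking the complex conjugate (gradient
descent A), the gradient step pushes the imaginary parts of the tensor
away from zero, even as the real parts shrink, so the Frobenius norm
grows. With the conjugate (gradient descent B), the norm drops by lr at
each step until the tensor is essentially zero.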
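The same behavior can be reproduced without pytorch by applying the two update rules by hand. This sketch assumes the version 1.6.0 convention seen above, in which the gradient of the Frobenius norm is conj(z) / ‖z‖_F. The helpers `frobenius` and `descend` are hypothetical. The starting values are the `za` entries printed above.

```python
import math

def frobenius(entries):
    # square root of the sum of squared magnitudes of the entries
    return math.sqrt(sum(abs(z) ** 2 for z in entries))

def descend(entries, lr, steps, use_conjugate):
    # emulates the 1.6.0-style gradient g = conj(z) / ||z||_F of the Frobenius norm
    for _ in range(steps):
        norm = frobenius(entries)
        grads = [z.conjugate() / norm for z in entries]
        if use_conjugate:
            grads = [g.conjugate() for g in grads]   # step along conj(g), as in gradient descent B
        entries = [z - lr * g for z, g in zip(entries, grads)]
    return entries

start = [0.8749 - 0.6791j, 1.0900 - 0.2884j, 0.6227 + 0.0374j,
         0.0531 + 0.3378j, -0.4779 - 1.5195j, -0.8105 - 0.1923j]

a = descend(start, lr=0.001, steps=1000, use_conjugate=False)
b = descend(start, lr=0.001, steps=1000, use_conjugate=True)
print(frobenius(start), frobenius(a), frobenius(b))
```

With `use_conjugate=True`, each step multiplies z by (1 − lr / ‖z‖_F). The norm therefore drops by exactly lr per step while it stays above lr, going from about 2.4970 to 1.4970 after 1000 steps, in line with the B output. Without the conjugate, the norm grows instead.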

As long as you write your custom loss function using pytorch tensor
operations, you will get autograd and backpropagation (but not
complex optimization) “for free.” You won’t have to write an explicit
.backward() function for your loss function.
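Here is a minimal sketch of such a loss, assuming real-valued tensors. It combines a mean-squared-error term with a Frobenius-norm penalty on a weight matrix. The function name `loss_with_frobenius`, the penalty weight `lam`, and the tensor shapes are all hypothetical. Autograd differentiates the loss without any hand-written backward.

```python
import torch

def loss_with_frobenius(pred, target, weight, lam=0.01):
    # data term: ordinary mean squared error
    mse = ((pred - target) ** 2).mean()
    # penalty term: Frobenius norm of the weight matrix, written with plain tensor ops
    frob = torch.sqrt((weight ** 2).sum())
    return mse + lam * frob

torch.manual_seed(2020)
W = torch.randn(3, 2, requires_grad=True)   # illustrative weight matrix
x = torch.randn(5, 3)                       # illustrative inputs
y = torch.randn(5, 2)                       # illustrative targets

loss = loss_with_frobenius(x @ W, y, W)
loss.backward()                             # autograd fills in W.grad
print(loss.item(), W.grad.shape)
```

For a complex weight, the squared magnitude (for example `weight.real ** 2 + weight.imag ** 2`) would replace `weight ** 2`, subject to the caveats above.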

This will certainly be true if you represent the real and imaginary parts
of your complex tensors as explicit real tensors.

That said, I think it’s worth trying pytorch’s complex tensors. If you
decide to go this route, you should use the latest version of pytorch
that otherwise works for you, and test your complex manipulations
carefully, especially the various functions you use and backpropagation.
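One way to carry out such a test assumes a recent PyTorch release in which `torch.complex()` and `torch.linalg.norm()` exist and support complex autograd. Compute the same Frobenius norm once through the complex-tensor path and once through the real-parts path. Then check that the values agree, and that the gradients with respect to the real inputs agree.

```python
import torch

torch.manual_seed(0)
x_re = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
x_im = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)

# Path 1: build a complex tensor and rely on pytorch's complex support
z = torch.complex(x_re, x_im)
norm_complex = torch.linalg.norm(z)              # flattened 2-norm, i.e. the Frobenius norm
g_re_1, g_im_1 = torch.autograd.grad(norm_complex, (x_re, x_im))

# Path 2: the same norm written with real arithmetic only
norm_real = torch.sqrt((x_re ** 2 + x_im ** 2).sum())
g_re_2, g_im_2 = torch.autograd.grad(norm_real, (x_re, x_im))

# If the complex path is correct, values and gradients should agree
print(torch.allclose(norm_complex.detach(), norm_real.detach()),
      torch.allclose(g_re_1, g_re_2),
      torch.allclose(g_im_1, g_im_2))
```

Double precision keeps rounding differences well inside the default `allclose` tolerances, so a mismatch points to a real discrepancy rather than noise.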

Good luck.

K. Frank