Why is it so hard to enforce a weight matrix to be orthogonal?

This post and this GitHub page explain how to add an orthogonality constraint to a weight matrix. However, it doesn’t seem to work.

For a matrix X to be orthogonal, we must have X^T X = I, so we can add |X^T X - I| to our loss. Here’s my code:

#make a random vector
X = torch.rand(30,500).to(device)
#make a random orthogonal matrix
rho = torch.nn.init.orthogonal_(torch.empty(500, 500)).to(device)
#X_rotated will be the target vector
X_target = X@rho
#The model is simply a single Linear layer
model = torch.nn.Linear(500,500).to(device)
#initialize the weight to an orthogonal matrix
model.weight.data.copy_(torch.nn.init.orthogonal_(torch.empty(500, 500)).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss(reduction='sum')

reg = 0.001
for t in range(10000):
  param = model.weight.data
  param_flat = param.view(param.shape[0], -1)
  sym = torch.mm(param_flat, torch.t(param_flat))
  sym -= torch.eye(param_flat.shape[0]).to(device)
  orthogonal_loss = reg * sym.abs().sum()
  loss =  loss_fn(model(X),X_target) + orthogonal_loss
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()
  if t%200==0:
    print("step:", t, "loss:{}".format(float(loss)), "orthogonal_loss:{}".format(float(sym.abs().sum())))

The optimal weight for the model is certainly rho, which gives zero loss. However, training doesn’t seem to converge to it, and the matrix it does converge to doesn’t seem to be orthogonal (high orthogonal loss):

step: 0 loss:9965.669921875 orthogonal_loss:0.0056331586092710495
step: 200 loss:9.945926666259766 orthogonal_loss:2980.79150390625
step: 400 loss:3.1001315116882324 orthogonal_loss:3038.67333984375
step: 600 loss:3.163803815841675 orthogonal_loss:3040.3330078125
step: 800 loss:3.425936222076416 orthogonal_loss:3040.452392578125
step: 1000 loss:3.042102575302124 orthogonal_loss:3040.4287109375
step: 1200 loss:3.091557502746582 orthogonal_loss:3040.39208984375
step: 1400 loss:3.171616554260254 orthogonal_loss:3040.3876953125
step: 1600 loss:4.268471717834473 orthogonal_loss:3040.52197265625
step: 1800 loss:4.954420566558838 orthogonal_loss:3040.474365234375
step: 2000 loss:3.115755319595337 orthogonal_loss:3040.40771484375
step: 2200 loss:4.3386921882629395 orthogonal_loss:3040.38623046875
step: 2400 loss:3.266144037246704 orthogonal_loss:3040.4541015625
step: 2600 loss:3.284057140350342 orthogonal_loss:3040.4365234375
step: 2800 loss:4.709336757659912 orthogonal_loss:3040.38427734375
step: 3000 loss:4.440422058105469 orthogonal_loss:3040.4404296875
step: 3200 loss:3.7141575813293457 orthogonal_loss:3040.435546875
step: 3400 loss:3.8447492122650146 orthogonal_loss:3040.53759765625
step: 3600 loss:5.975290775299072 orthogonal_loss:3040.39794921875
step: 3800 loss:3.474747657775879 orthogonal_loss:3040.509521484375
step: 4000 loss:4.279032230377197 orthogonal_loss:3040.54296875
step: 4200 loss:4.369743347167969 orthogonal_loss:3040.31640625
step: 4400 loss:7.692440986633301 orthogonal_loss:3040.627685546875
step: 4600 loss:5.032724380493164 orthogonal_loss:3040.314697265625
step: 4800 loss:7.126654148101807 orthogonal_loss:3040.60693359375
step: 5000 loss:3.818039655685425 orthogonal_loss:3040.35546875
step: 5200 loss:4.2421369552612305 orthogonal_loss:3040.56103515625
step: 5400 loss:6.937448024749756 orthogonal_loss:3040.30712890625
step: 5600 loss:3.442885637283325 orthogonal_loss:3040.4599609375
step: 5800 loss:3.3514583110809326 orthogonal_loss:3040.493408203125
step: 6000 loss:4.078462600708008 orthogonal_loss:3040.5341796875
step: 6200 loss:3.448216199874878 orthogonal_loss:3040.46630859375
step: 6400 loss:4.94446325302124 orthogonal_loss:3040.5009765625
step: 6600 loss:4.0663652420043945 orthogonal_loss:3040.351318359375
step: 6800 loss:4.430315971374512 orthogonal_loss:3040.33837890625
step: 7000 loss:4.341968536376953 orthogonal_loss:3040.444580078125
step: 7200 loss:5.806286811828613 orthogonal_loss:3040.665283203125
step: 7400 loss:4.211328029632568 orthogonal_loss:3040.381103515625
step: 7600 loss:3.4469919204711914 orthogonal_loss:3040.38037109375
step: 7800 loss:4.388216495513916 orthogonal_loss:3040.4833984375
step: 8000 loss:3.740677833557129 orthogonal_loss:3040.37841796875
step: 8200 loss:6.28993034362793 orthogonal_loss:3040.64404296875
step: 8400 loss:4.638856887817383 orthogonal_loss:3040.411376953125
step: 8600 loss:3.8260600566864014 orthogonal_loss:3040.49462890625
step: 8800 loss:4.07125186920166 orthogonal_loss:3040.52685546875
step: 9000 loss:3.7882349491119385 orthogonal_loss:3040.50244140625
step: 9200 loss:3.6767845153808594 orthogonal_loss:3040.515380859375
step: 9400 loss:4.998872756958008 orthogonal_loss:3040.62939453125
step: 9600 loss:4.7679219245910645 orthogonal_loss:3040.3095703125
step: 9800 loss:3.294722318649292 orthogonal_loss:3040.46875

Hi Zeyuyun!

While this “orthogonality penalty” is zero if, and only if, X is orthogonal,
and is positive otherwise, it doesn’t work well for orthogonalizing X
with gradient descent, because its structure doesn’t fit* well with the
natural geometry of the set of orthogonal matrices.

Replacing the sum of absolute values with the sum of squares gives
a much more natural penalty that works well with gradient descent.

Please take a look at the modified version of your script, below, that
successfully applies the sum-of-squares orthogonality penalty to the
optimization of model.weight.

Also, as written, param is no longer part of the computation graph (it has
been set to model.weight.data), so gradients of orthogonal_loss won’t
backpropagate and won’t affect the training of model.weight.
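
A minimal sketch of the .data issue, separate from the script in this thread (the 4x4 layer and the names pen_detached and pen_tracked are made up for illustration):

```python
import torch

# hypothetical small layer, just for illustration
layer = torch.nn.Linear(4, 4)
eye = torch.eye(4)

# penalty built from .data: detached from the autograd graph
w_detached = layer.weight.data
pen_detached = (w_detached @ w_detached.T - eye).pow(2).sum()

# penalty built from the parameter itself: tracked by autograd
w_tracked = layer.weight
pen_tracked = (w_tracked @ w_tracked.T - eye).pow(2).sum()

print(pen_detached.requires_grad)   # False
print(pen_tracked.requires_grad)    # True

# only the tracked penalty can send a gradient back to layer.weight
pen_tracked.backward()
print(layer.weight.grad is not None)
```

Calling backward() on a loss that contains only the detached penalty would leave the penalty's contribution to model.weight.grad at zero, which is why the penalty had no effect on training.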

Finally, I substantially increased reg, the weight of the orthogonality penalty
in the total loss, from 0.001 to 1000.0.

Even with these fixes, model.weight should not be expected to converge to
rho: with only 30 vectors to train on, you don’t have enough information
to reconstruct the 500x500 matrix rho.
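
To make the "not enough information" point concrete, here is a small sketch with the dimensions shrunk to 6 and 2 in place of 500 and 30. The names H and W_other are mine, and the construction (a 180-degree rotation in a plane orthogonal to all of the sample vectors) is just one convenient way to build a second solution:

```python
import torch

torch.manual_seed(0)
n, n_batch = 6, 2   # small stand-ins for 500 and 30

X = torch.rand(n_batch, n, dtype=torch.float64)
rho = torch.nn.init.orthogonal_(torch.empty(n, n, dtype=torch.float64))

# rows of Vh beyond the first n_batch are orthogonal to every row of X
# (assumes X has full rank, which a random X almost surely does)
_, _, Vh = torch.linalg.svd(X, full_matrices=True)
v1, v2 = Vh[-1], Vh[-2]

# rotation by 180 degrees in the (v1, v2) plane: orthogonal, determinant +1,
# and it leaves every row of X unchanged
eye = torch.eye(n, dtype=torch.float64)
H = eye - 2.0 * torch.outer(v1, v1) - 2.0 * torch.outer(v2, v2)

W_other = H @ rho

print(torch.allclose(W_other.T @ W_other, eye))   # W_other is orthogonal
print(torch.allclose(X @ W_other, X @ rho))       # it produces the same targets as rho
print(float((W_other - rho).abs().max()))         # yet it is a different matrix
```

So even a perfectly orthogonal matrix with the right determinant can fit all of the training vectors without being rho.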

There are a couple more issues:

In X_target = X@rho you right-multiply X by rho. Not a big deal, but
because Linear computes X @ weight.T + bias, it means that you will be
training model.weight to become equal to rho.T (rather than rho).
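
A quick check of the transpose convention, as a sketch with a made-up 5x5 example (the name layer and the sizes are mine, not from the script):

```python
import torch

torch.manual_seed(0)
n = 5
X = torch.rand(3, n)
rho = torch.nn.init.orthogonal_(torch.empty(n, n))

layer = torch.nn.Linear(n, n)
with torch.no_grad():
    layer.weight.copy_(rho.T)   # weight set to rho.T, not rho
    layer.bias.zero_()

# Linear computes X @ weight.T + bias, so weight = rho.T reproduces X @ rho
print(torch.allclose(layer(X), X @ rho, atol=1e-6))
# the same layer does not reproduce X @ rho.T (unless rho happens to be symmetric)
print(torch.allclose(layer(X), X @ rho.T, atol=1e-6))
```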

Lastly, orthogonal matrices come in two disconnected sets: those with
determinant +1 and those with determinant -1. Although the optimizer
takes finite steps, these steps are “small,” so the optimization procedure
evolves model.weight continuously from its initial to its final value.

If you happen to choose rho with determinant +1, but also happen to
initialize model.weight with determinant -1, the optimizer will not** be
able to train model.weight to become rho.
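
A related detail if you want to force a matrix into one of the two sets: negating the whole matrix multiplies the determinant by (-1)^n, so it only flips the sign when the dimension n is odd. Negating a single row flips it in any dimension and keeps the matrix orthogonal. A small sketch (the helper names det_sign and signs are mine):

```python
import torch

torch.manual_seed(0)

def det_sign(m):
    # sign of the determinant, as returned by torch.slogdet
    return int(torch.slogdet(m)[0])

def signs(n):
    # returns det signs of Q, of -Q, and of Q with its first row negated
    Q = torch.nn.init.orthogonal_(torch.empty(n, n, dtype=torch.float64))
    neg_all = -Q
    neg_row = Q.clone()
    neg_row[0] *= -1
    return det_sign(Q), det_sign(neg_all), det_sign(neg_row)

for n in (4, 5):   # one even and one odd dimension
    print(n, signs(n))
```

For n = 4 the second sign equals the first, while the third is always the opposite of the first.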

*) Crudely, you can think of the geometry of the set of orthogonal matrices
as being like a circle, with a rotational symmetry. The sum of absolute
values singles out the x and y directions, not respecting that symmetry.
The sum of squares respects the rotational symmetry. This is part of the
reason that the sum of squares works better as a penalty.

**) Hypothetically, the optimizer could take a large, fortuitous step that
flips the sign of the determinant of model.weight, but this would be
highly unlikely.

Here is the modified script, with some informative comments:

import torch
torch.__version__

torch.random.manual_seed (2021)

nOrt = 500
nBatch = 30   # not enough rotated vectors to recover 500x500 rho

nEpoch = 1000
nFreq = 100

# make a batch of random vectors
X = torch.rand (nBatch, nOrt)

print ('X.shape =', X.shape)

# make a random orthogonal matrix with determinant +1
rho = torch.nn.init.orthogonal_ (torch.empty (nOrt, nOrt))
if  torch.slogdet (rho)[0] < 0.0:
    rho[0] *= -1   # flip determinant by negating one row -- rho *= -1 would not flip it when nOrt is even

# the target vector will be the rotated X
# note that X is right-multiplied by rho
# this will cause model.weight to train to match rho.T, rather than rho itself
X_target = X @ rho

# the model is simply a single Linear layer (with bias)
model = torch.nn.Linear (nOrt, nOrt)
# initialize model.weight to another orthogonal matrix with determinant +1 to match det (rho)
wInit = torch.nn.init.orthogonal_ (torch.empty (nOrt, nOrt))
if  torch.slogdet (wInit)[0] < 0.0:
    wInit[0] *= -1   # flip determinant by negating one row -- wInit *= -1 would not flip it when nOrt is even

with torch.no_grad():
    _ = model.weight.copy_ (wInit)

optimizer = torch.optim.Adam (model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss (reduction='sum')

# reg = 0.001   # weight orthogonality-penalty much more heavily
reg = 1000.0
for t in range (nEpoch):
    # param = model.weight.data   # breaks computation graph --  model.weight not optimized
    param = model.weight
    sym = torch.mm (param, torch.t (param))
    sym -= torch.eye (param.shape[0])
    # ls_ort = sym.abs().sum()   # poor match to geometry of orthogonal matrices
    ls_ort = sym.pow (2.0).sum()
    ls_fit = loss_fn (model(X), X_target)
    ls_tot = ls_fit + reg * ls_ort
    with torch.no_grad():
        ldet = torch.slogdet (param)[1]
    optimizer.zero_grad()
    ls_tot.backward()
    optimizer.step()
    if t <= 10  or  t % nFreq == 0  or  t >= nEpoch - 1:
        print ('epoch:{:4d}'.format (t), ' ls_tot:{:9.3f}'.format (float (ls_tot)), ' ls_fit:{:9.3f}'.format (float (ls_fit)), ' ls_ort: {:.2e}'.format (float (ls_ort)), ' ldet: {: .1e}'.format (float (ldet)))

# check rho and model.weight determinants
print ('slogdet (rho)    =  {: 1.0f}  {: .3e}'.format (torch.slogdet (rho)[0], torch.slogdet(rho)[1]))
print ('slogdet (param)  =  {: 1.0f}  {: .3e}'.format (torch.slogdet (param)[0], torch.slogdet(param)[1]))

# check fit of model.weight to rho
with torch.no_grad():
    fit = (param - rho).abs().sum()
    fitT = (param - rho.T).abs().sum()

print ('fit = {:13.5f}, fitT = {:13.5f}'.format (fit, fitT))

And here is its output:

X.shape = torch.Size([30, 500])
epoch:   0  ls_tot: 9968.365  ls_fit: 9968.365  ls_ort: 2.27e-10  ldet:  6.9e-05
epoch:   1  ls_tot: 6984.131  ls_fit: 6536.737  ls_ort: 4.47e-01  ldet: -4.1e-01
epoch:   2  ls_tot: 5456.405  ls_fit: 4554.480  ls_ort: 9.02e-01  ldet: -9.2e-01
epoch:   3  ls_tot: 4708.714  ls_fit: 3511.573  ls_ort: 1.20e+00  ldet: -1.4e+00
epoch:   4  ls_tot: 4356.467  ls_fit: 3003.020  ls_ort: 1.35e+00  ldet: -1.8e+00
epoch:   5  ls_tot: 4166.570  ls_fit: 2764.880  ls_ort: 1.40e+00  ldet: -2.1e+00
epoch:   6  ls_tot: 4033.207  ls_fit: 2654.856  ls_ort: 1.38e+00  ldet: -2.1e+00
epoch:   7  ls_tot: 3915.477  ls_fit: 2601.749  ls_ort: 1.31e+00  ldet: -1.9e+00
epoch:   8  ls_tot: 3786.879  ls_fit: 2565.707  ls_ort: 1.22e+00  ldet: -1.7e+00
epoch:   9  ls_tot: 3633.864  ls_fit: 2528.301  ls_ort: 1.11e+00  ldet: -1.4e+00
epoch:  10  ls_tot: 3459.893  ls_fit: 2488.102  ls_ort: 9.72e-01  ldet: -1.2e+00
epoch: 100  ls_tot:  633.859  ls_fit:  630.421  ls_ort: 3.44e-03  ldet: -6.8e-02
epoch: 200  ls_tot:   47.161  ls_fit:   46.997  ls_ort: 1.64e-04  ldet: -6.4e-03
epoch: 300  ls_tot:    3.470  ls_fit:    3.461  ls_ort: 9.53e-06  ldet: -8.9e-04
epoch: 400  ls_tot:    0.318  ls_fit:    0.317  ls_ort: 7.55e-07  ldet: -2.2e-04
epoch: 500  ls_tot:    0.037  ls_fit:    0.036  ls_ort: 7.71e-07  ldet:  2.4e-05
epoch: 600  ls_tot:    0.014  ls_fit:    0.011  ls_ort: 3.07e-06  ldet: -2.9e-05
epoch: 700  ls_tot:    0.005  ls_fit:    0.004  ls_ort: 7.76e-07  ldet:  5.3e-05
epoch: 800  ls_tot:    0.007  ls_fit:    0.005  ls_ort: 1.90e-06  ldet: -6.9e-05
epoch: 900  ls_tot:    0.028  ls_fit:    0.016  ls_ort: 1.12e-05  ldet: -1.4e-04
epoch: 999  ls_tot:    0.021  ls_fit:    0.013  ls_ort: 8.38e-06  ldet: -5.8e-05
slogdet (rho)    =   1   5.507e-05
slogdet (param)  =   1  -4.864e-05
fit =   12627.76172, fitT =   12255.37109

This illustrates that both the “fitting” loss (MSELoss) and the orthogonality
penalty train down to low values. Note, however, that, as expected,
model.weight does not train to match rho (nor rho.T) (fit and fitT,
respectively) because the 30 sample vectors are not enough to fully
reconstruct rho.

Best.

K. Frank

WOW! Thank you so much again for the extremely thorough answer. I am still very new to PyTorch, sorry for making so many mistakes.

Would you mind explaining a little more why the squared loss (Frobenius norm) is better for training? I can understand that the orthogonal group is like a circle, but I can’t see what that has to do with the squared loss.

Hi Zeyuyun!

I don’t know if what I said can be made precise – for me it is analogy and
intuition.

Go back to the circle – let’s say of radius 1 in the x-y plane. Consider
a point off the circle, (0.01, 1.50). The closest point on the circle is
approximately (0.0, 1.0). So x = 0.01 is nearly right, while y = 1.50
is a ways off.

With the sum-of-absolute-values, the gradient will be of magnitude one
in both directions. So a gradient-descent step will change both x and y
by the same amount, even though y should be changed much more
than x. With the sum-of-squares, the gradient will be directed outward
from the center of the circle, so a gradient-descent step will move the
point (x, y) directly towards the center of the circle, which is the same
as moving it directly towards the nearest point on the circle (which is
what we want).
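
Here is a rough numerical version of this picture. It takes the analogy literally and simply compares the gradients of |x| + |y| and x^2 + y^2 at the point (0.01, 1.50); that choice of functions is my assumption for illustration, not an actual orthogonality penalty:

```python
import torch

p = torch.tensor([0.01, 1.50], dtype=torch.float64, requires_grad=True)

# assumption: compare these two functions of (x, y), per the analogy
g_abs = torch.autograd.grad(p.abs().sum(), p)[0]    # gradient of |x| + |y|
g_sq = torch.autograd.grad(p.pow(2).sum(), p)[0]    # gradient of x^2 + y^2

# unit vector pointing from the center of the circle out through the point
radial = p.detach() / p.detach().norm()

def cosine(a, b):
    # cosine of the angle between two vectors
    return float(a @ b / (a.norm() * b.norm()))

print(g_abs)                      # both components have magnitude one
print(g_sq)                       # parallel to (x, y)
print(cosine(g_abs, radial))      # noticeably less than 1
print(cosine(g_sq, radial))       # 1 up to rounding: purely radial
```

The sum-of-squares gradient lines up with the radial direction, while the sum-of-absolute-values gradient points along the diagonal no matter where the point is.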

Now, depending on the learning rate, we might not move all the way to
the circle, staying on the outside. Or we might overshoot, jumping to the
inside of the circle, but at least we’re moving in the right direction.

(Also, the sum-of-squares is “softer.” As you get closer to the circle, the
gradients become smaller, and you take smaller steps, which tends to be
good. With the sum-of-absolute-values the magnitude of both components
of the gradient will always be 1, regardless of how close you are to the
circle, so you can easily get in a situation where you keep jumping back
and forth between the inside and outside of the circle, without actually
getting closer to the circle (unless you have a scheme for reducing your
learning rate while you are doing this).)
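
The "jumping back and forth" effect shows up even in one dimension. In this sketch d stands for the signed distance from the circle, and the starting value, learning rate, and step count are hypothetical values picked only to show the behavior:

```python
def descend(grad, d, lr, steps):
    # plain gradient descent on a single number d with a fixed learning rate
    for _ in range(steps):
        d = d - lr * grad(d)
    return d

def grad_abs(d):
    # derivative of |d| (taken as 0 at d == 0)
    return (d > 0) - (d < 0)

def grad_sq(d):
    # derivative of d**2
    return 2.0 * d

d0, lr, steps = 0.25, 0.1, 50   # hypothetical starting offset and learning rate

d_abs = descend(grad_abs, d0, lr, steps)
d_sq = descend(grad_sq, d0, lr, steps)

print(abs(d_abs))   # stuck hopping between about +lr/2 and -lr/2
print(abs(d_sq))    # shrinks by a factor of 0.8 every step
```

With the absolute value the step size never shrinks, so the iterate ends up hopping across zero; with the square, the step scales with the remaining distance and the iterate settles down.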

The geometry of 500x500 orthogonal matrices has much more structure
than that of a circle, but this is the basic idea of what is going on. (At least
this is what I think is going on …)

Best.

K. Frank

Thanks for the detailed intuition on the gradient! Now it makes perfect sense.

Again, thank you for the help overall. It’s super helpful!