This post and this GitHub page explain how to add an orthogonality constraint to a weight matrix. However, it doesn't seem to work.
For a matrix W to be orthogonal, we must have W^T W = I, so we can add |W^T W - I| to our loss. Here's my code:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# make a random input matrix
X = torch.rand(30, 500).to(device)
# make a random orthogonal matrix
rho = torch.nn.init.orthogonal_(torch.empty(500, 500)).to(device)
# X @ rho will be the target
X_target = X @ rho
# the model is simply a single Linear layer
model = torch.nn.Linear(500, 500).to(device)
# initialize the weight to an orthogonal matrix
model.weight.data.copy_(torch.nn.init.orthogonal_(torch.empty(500, 500)).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss(reduction='sum')
reg = 0.001
for t in range(10000):
    param = model.weight.data
    param_flat = param.view(param.shape[0], -1)
    sym = torch.mm(param_flat, torch.t(param_flat))
    sym -= torch.eye(param_flat.shape[0]).to(device)
    orthogonal_loss = reg * sym.abs().sum()
    loss = loss_fn(model(X), X_target) + orthogonal_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 200 == 0:
        print("step: {}".format(t), "loss:{}".format(float(loss)), "orthogonal_loss:{}".format(float(sym.abs().sum())))
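(Side note: since nn.Linear computes x @ W^T + b, the weight that gives zero MSE is actually rho transposed, which is still orthogonal. A quick standalone check of this, my own snippet with the names reused from above:)

```python
import torch

torch.manual_seed(0)
X = torch.rand(30, 500)
rho = torch.nn.init.orthogonal_(torch.empty(500, 500))
X_target = X @ rho

# nn.Linear computes x @ W^T + b, so setting W = rho^T (and no bias)
# makes the layer compute exactly X @ rho
model = torch.nn.Linear(500, 500, bias=False)
with torch.no_grad():
    model.weight.copy_(rho.t())

err = ((model(X) - X_target) ** 2).sum()
print(float(err))  # tiny, up to float32 rounding
```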
The optimal weight for the model is certainly rho^T (which is itself orthogonal), so zero loss is achievable. However, training doesn't seem to converge to it, and the matrix it converges to doesn't seem to be orthogonal (the orthogonal loss stays high):
step: 0 loss:9965.669921875 orthogonal_loss:0.0056331586092710495
step: 200 loss:9.945926666259766 orthogonal_loss:2980.79150390625
step: 400 loss:3.1001315116882324 orthogonal_loss:3038.67333984375
step: 600 loss:3.163803815841675 orthogonal_loss:3040.3330078125
step: 800 loss:3.425936222076416 orthogonal_loss:3040.452392578125
step: 1000 loss:3.042102575302124 orthogonal_loss:3040.4287109375
step: 1200 loss:3.091557502746582 orthogonal_loss:3040.39208984375
step: 1400 loss:3.171616554260254 orthogonal_loss:3040.3876953125
step: 1600 loss:4.268471717834473 orthogonal_loss:3040.52197265625
step: 1800 loss:4.954420566558838 orthogonal_loss:3040.474365234375
step: 2000 loss:3.115755319595337 orthogonal_loss:3040.40771484375
step: 2200 loss:4.3386921882629395 orthogonal_loss:3040.38623046875
step: 2400 loss:3.266144037246704 orthogonal_loss:3040.4541015625
step: 2600 loss:3.284057140350342 orthogonal_loss:3040.4365234375
step: 2800 loss:4.709336757659912 orthogonal_loss:3040.38427734375
step: 3000 loss:4.440422058105469 orthogonal_loss:3040.4404296875
step: 3200 loss:3.7141575813293457 orthogonal_loss:3040.435546875
step: 3400 loss:3.8447492122650146 orthogonal_loss:3040.53759765625
step: 3600 loss:5.975290775299072 orthogonal_loss:3040.39794921875
step: 3800 loss:3.474747657775879 orthogonal_loss:3040.509521484375
step: 4000 loss:4.279032230377197 orthogonal_loss:3040.54296875
step: 4200 loss:4.369743347167969 orthogonal_loss:3040.31640625
step: 4400 loss:7.692440986633301 orthogonal_loss:3040.627685546875
step: 4600 loss:5.032724380493164 orthogonal_loss:3040.314697265625
step: 4800 loss:7.126654148101807 orthogonal_loss:3040.60693359375
step: 5000 loss:3.818039655685425 orthogonal_loss:3040.35546875
step: 5200 loss:4.2421369552612305 orthogonal_loss:3040.56103515625
step: 5400 loss:6.937448024749756 orthogonal_loss:3040.30712890625
step: 5600 loss:3.442885637283325 orthogonal_loss:3040.4599609375
step: 5800 loss:3.3514583110809326 orthogonal_loss:3040.493408203125
step: 6000 loss:4.078462600708008 orthogonal_loss:3040.5341796875
step: 6200 loss:3.448216199874878 orthogonal_loss:3040.46630859375
step: 6400 loss:4.94446325302124 orthogonal_loss:3040.5009765625
step: 6600 loss:4.0663652420043945 orthogonal_loss:3040.351318359375
step: 6800 loss:4.430315971374512 orthogonal_loss:3040.33837890625
step: 7000 loss:4.341968536376953 orthogonal_loss:3040.444580078125
step: 7200 loss:5.806286811828613 orthogonal_loss:3040.665283203125
step: 7400 loss:4.211328029632568 orthogonal_loss:3040.381103515625
step: 7600 loss:3.4469919204711914 orthogonal_loss:3040.38037109375
step: 7800 loss:4.388216495513916 orthogonal_loss:3040.4833984375
step: 8000 loss:3.740677833557129 orthogonal_loss:3040.37841796875
step: 8200 loss:6.28993034362793 orthogonal_loss:3040.64404296875
step: 8400 loss:4.638856887817383 orthogonal_loss:3040.411376953125
step: 8600 loss:3.8260600566864014 orthogonal_loss:3040.49462890625
step: 8800 loss:4.07125186920166 orthogonal_loss:3040.52685546875
step: 9000 loss:3.7882349491119385 orthogonal_loss:3040.50244140625
step: 9200 loss:3.6767845153808594 orthogonal_loss:3040.515380859375
step: 9400 loss:4.998872756958008 orthogonal_loss:3040.62939453125
step: 9600 loss:4.7679219245910645 orthogonal_loss:3040.3095703125
step: 9800 loss:3.294722318649292 orthogonal_loss:3040.46875
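One thing I now suspect (my own guess, not confirmed): `param = model.weight.data` detaches the weight from the autograd graph, so `orthogonal_loss` contributes no gradient at all and only the MSE term gets optimized, which would explain why the penalty climbs and then just sits there. Here is a sketch of the loop computing the penalty on `model.weight` itself, out-of-place so autograd can track it:

```python
import torch

torch.manual_seed(0)
device = torch.device("cpu")

X = torch.rand(30, 500).to(device)
rho = torch.nn.init.orthogonal_(torch.empty(500, 500)).to(device)
X_target = X @ rho

model = torch.nn.Linear(500, 500).to(device)
with torch.no_grad():
    model.weight.copy_(torch.nn.init.orthogonal_(torch.empty(500, 500)))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss(reduction='sum')
reg = 0.001
eye = torch.eye(500, device=device)

for t in range(200):
    W = model.weight                # the live parameter, NOT .data
    sym = W @ W.t() - eye          # out-of-place, stays in the graph
    orthogonal_loss = reg * sym.abs().sum()
    loss = loss_fn(model(X), X_target) + orthogonal_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# the penalty now has a grad_fn, so it actually shapes the weights
print(bool(orthogonal_loss.requires_grad))
```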