Hi Akusa!
You can add an orthogonality penalty to your loss function to push var1
towards being orthogonal. Note that doing so will not force var1 to be
exactly orthogonal: if the rest of your loss function prefers a
non-orthogonal var1, the optimization procedure will make a trade-off
between getting the lowest value for your regular loss and having var1
be orthogonal. You can tune this trade-off by multiplying the penalty
term by a weighting factor. The larger you make the weight of the
penalty, the closer to orthogonal var1 will become.
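As a minimal sketch of that trade-off (the mean-squared-error "regular"
loss, the target tensor, and the name ortho_weight below are made up
purely for illustration – substitute your actual loss), the combined
objective could look like this:

import torch

def ortho_penalty (t):
    # zero exactly when the rows of t are orthonormal
    return ((t @ t.T - torch.eye (t.shape[0]))**2).sum()

n, D = 5, 7
var1 = torch.randn ((n, D), requires_grad = True)
target = torch.randn ((n, D))       # stand-in data for the "regular" loss
optimizer = torch.optim.Adam ([var1], lr = 0.1)
ortho_weight = 10.0                 # larger -> var1 ends up closer to orthogonal

for ite in range (200):
    optimizer.zero_grad()
    regular_loss = ((var1 - target)**2).mean()    # your real loss goes here
    loss = regular_loss + ortho_weight * ortho_penalty (var1)
    loss.backward()
    optimizer.step()

With ortho_weight set to zero you recover your original problem; as you
increase it, var1 @ var1.T gets driven towards the identity.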
This script adds an orthogonality penalty to your code fragment:
import torch
print (torch.__version__)

_ = torch.random.manual_seed (2021)

def ortho_penalty (t):
    # sum of squared deviations of t @ t.T from the identity --
    # zero exactly when the rows of t are orthonormal
    return ((t @ t.T - torch.eye (t.shape[0]))**2).sum()

n = 5
D = 7
lr = 0.1
ite_max = 200

var1 = torch.randn ((n, D), requires_grad = True)   # tensor of shape n x D
optimizer = torch.optim.Adam([var1], lr = lr)

for ite in range(ite_max):
    optimizer.zero_grad()
    loss = ortho_penalty (var1)   # here the penalty is the entire loss
    if ite == ite_max - 1 or ite % 20 == 0:
        print ('loss =', loss.item())
    loss.backward()
    optimizer.step()

print ('var1 = ...\n', var1)
print ('var1 @ var1.T = ...\n', var1 @ var1.T)
Here is its output:
1.7.1
loss = 272.65423583984375
loss = 0.8204055428504944
loss = 0.2746223211288452
loss = 0.055152103304862976
loss = 0.0036679445765912533
loss = 0.0007598181255161762
loss = 7.743481546640396e-05
loss = 1.1900866411451716e-05
loss = 4.0019048697104154e-07
loss = 1.5665852970414562e-07
loss = 1.9611301027566697e-08
var1 = ...
tensor([[-0.5975, -0.3254, 0.2185, -0.0053, -0.3276, -0.4471, 0.4268],
[-0.1584, -0.7288, -0.2472, -0.2933, -0.1118, 0.2660, -0.4619],
[-0.4144, 0.5079, 0.3184, -0.3310, -0.2745, 0.0373, -0.5317],
[ 0.5968, -0.0300, -0.0399, -0.1651, -0.6691, -0.3961, -0.0974],
[-0.0128, 0.0696, 0.0430, 0.0669, -0.5315, 0.7458, 0.3873]],
requires_grad=True)
var1 @ var1.T = ...
tensor([[ 1.0000e+00, -2.4647e-05, -1.3560e-05, 3.6400e-05, 3.8207e-05],
[-2.4647e-05, 1.0000e+00, -7.5102e-06, -9.2573e-06, -1.5244e-05],
[-1.3545e-05, -7.5102e-06, 1.0000e+00, 5.3346e-06, -4.1515e-05],
[ 3.6404e-05, -9.2536e-06, 5.3346e-06, 1.0000e+00, 1.0524e-05],
[ 3.8207e-05, -1.5244e-05, -4.1515e-05, 1.0554e-05, 9.9999e-01]],
grad_fn=<MmBackward>)
Best.
K. Frank