# Add orthogonal constraint to optimizer

Say you want to optimize the parameters of your model, where the parameters form an n x D tensor.

``````
var1 = torch.tensor(..., requires_grad = True)  # tensor of shape n x D
optimizer = torch.optim.Adam([var1], lr = 1e-5)
for ite in range(ite_max):
    optimizer.zero_grad()
    loss = loss_function(var1)
    loss.backward()
    optimizer.step()
``````

I want to add an orthogonal constraint on my parameter var1, that is,

``````
var1[:, i].T @ var1[:, j] = kronecker_ij
``````

where `kronecker_ij` is the Kronecker delta: 1 if i = j, 0 otherwise.

How should I do it?

Hi Akusa!

You can add an orthogonal penalty to your loss function to push `var1`
towards being orthogonal.

Note, doing so will not force `var1` to be exactly orthogonal – if the rest
of your loss function prefers a non-orthogonal `var1`, the optimization
procedure will make a trade-off between getting the lowest value for your
regular loss and having `var1` be orthogonal. You can tune this trade-off
by multiplying the penalty term by a weighting factor: the larger the
weight of the penalty, the closer to orthogonal `var1` will become.

This script adds an orthogonality penalty to your code fragment. (One caveat: with n = 5 < D = 7, the seven columns of `var1` cannot all be mutually orthonormal, so the penalty below pushes the five rows of `var1` towards orthonormality; to penalize the columns instead, swap in `t.T @ t - torch.eye (t.shape[1])`.)

``````
import torch
print (torch.__version__)

_ = torch.random.manual_seed (2021)

def ortho_penalty (t):
    # squared Frobenius norm of (t t^T - I); zero exactly when the rows of t are orthonormal
    return ((t @ t.T - torch.eye (t.shape[0]))**2).sum()

n = 5
D = 7
lr = 0.1
ite_max = 200

var1 = torch.randn ((n, D), requires_grad = True)  # tensor of shape n x D
optimizer = torch.optim.Adam ([var1], lr = lr)
for ite in range (ite_max):
    optimizer.zero_grad()
    loss = ortho_penalty (var1)
    if ite == ite_max - 1  or  ite % 20 == 0:
        print ('loss =', loss.item())
    loss.backward()
    optimizer.step()

print ('var1 = ...\n', var1)
print ('var1 @ var1.T = ...\n', var1 @ var1.T)
``````

Here is its output:

``````
1.7.1
loss = 272.65423583984375
loss = 0.8204055428504944
loss = 0.2746223211288452
loss = 0.055152103304862976
loss = 0.0036679445765912533
loss = 0.0007598181255161762
loss = 7.743481546640396e-05
loss = 1.1900866411451716e-05
loss = 4.0019048697104154e-07
loss = 1.5665852970414562e-07
loss = 1.9611301027566697e-08
var1 = ...
 tensor([[-0.5975, -0.3254,  0.2185, -0.0053, -0.3276, -0.4471,  0.4268],
        [-0.1584, -0.7288, -0.2472, -0.2933, -0.1118,  0.2660, -0.4619],
        [-0.4144,  0.5079,  0.3184, -0.3310, -0.2745,  0.0373, -0.5317],
        [ 0.5968, -0.0300, -0.0399, -0.1651, -0.6691, -0.3961, -0.0974],
        [-0.0128,  0.0696,  0.0430,  0.0669, -0.5315,  0.7458,  0.3873]],
       requires_grad=True)
var1 @ var1.T = ...
 tensor([[ 1.0000e+00, -2.4647e-05, -1.3560e-05,  3.6400e-05,  3.8207e-05],
        [-2.4647e-05,  1.0000e+00, -7.5102e-06, -9.2573e-06, -1.5244e-05],
        [-1.3545e-05, -7.5102e-06,  1.0000e+00,  5.3346e-06, -4.1515e-05],
        [ 3.6404e-05, -9.2536e-06,  5.3346e-06,  1.0000e+00,  1.0524e-05],
        [ 3.8207e-05, -1.5244e-05, -4.1515e-05,  1.0554e-05,  9.9999e-01]],
       grad_fn=<MmBackward>)
``````

To use the penalty together with your regular loss, replace the `loss` line in your training loop with

`loss = loss_function(var1) + w * ortho_penalty(var1)`

where `w > 0` is the penalty weight.
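For completeness, here is a minimal sketch of that combined objective. The quadratic `loss_function` (pulling `var1` towards all ones) and the weight `w = 10.0` are made-up stand-ins for your real task loss and tuned weight:

``````
import torch

_ = torch.random.manual_seed (2021)

def ortho_penalty (t):
    # squared Frobenius norm of (t t^T - I)
    return ((t @ t.T - torch.eye (t.shape[0]))**2).sum()

def loss_function (t):
    # toy stand-in for your real task loss
    return ((t - torch.ones_like (t))**2).mean()

n, D = 5, 7
w = 10.0  # hypothetical penalty weight -- larger values push var1 closer to orthogonal

var1 = torch.randn ((n, D), requires_grad = True)
optimizer = torch.optim.Adam ([var1], lr = 0.05)
for ite in range (500):
    optimizer.zero_grad()
    loss = loss_function (var1) + w * ortho_penalty (var1)
    loss.backward()
    optimizer.step()

print ('final penalty =', ortho_penalty (var1).item())
``````

Because this toy task loss prefers a non-orthogonal `var1` (a matrix of all ones has identical rows), the final penalty settles near, but not exactly at, zero -- that is the trade-off described above.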
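If you need `var1` to be exactly orthogonal rather than approximately so via a penalty, newer versions of PyTorch (1.9 and later) provide `torch.nn.utils.parametrizations.orthogonal`, which reparametrizes a module's weight so that it stays orthogonal throughout optimization. A minimal sketch, assuming PyTorch >= 1.9 and wrapping the parameter in a `Linear` module:

``````
import torch

lin = torch.nn.Linear (7, 5, bias = False)  # weight is a 5 x 7 tensor
torch.nn.utils.parametrizations.orthogonal (lin, name = 'weight')

# the weight now always has orthonormal rows -- no penalty term needed
W = lin.weight
print (W @ W.T)  # approximately the 5 x 5 identity

# gradients flow through the parametrization, so you optimize as usual:
optimizer = torch.optim.Adam (lin.parameters(), lr = 1e-3)
``````

With this approach there is no trade-off to tune, but every forward pass pays the (modest) cost of reconstructing the orthogonal weight from the underlying parametrization.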