# Is there an orthogonal Procrustes for PyTorch?

Inspired by this paper, as a potentially better representation-similarity metric: [2108.01661] Grounding Representation Similarity with Statistical Testing (https://arxiv.org/abs/2108.01661)

Note that SciPy does have it: `scipy.linalg.orthogonal_procrustes` (SciPy v1.7.1 manual).

This is probably enough:

``````def orthogonal_procrustes_distance(x1: Tensor, x2: Tensor, normalize: bool = False) -> Tensor:
"""
Computes the orthoginal procrustes distance.
If normalized then the answer is divided by 2 so that it's in the interval [0, 1].

Expected input:
- two matrices e.g.
- two weight matrices of size [num_weights1, num_weights2]
- or two matrices of activations [batch_size, dim_of_layer] (used by paper [1])

d_proc(A*, B) = ||A||^2_F + ||B||^2_F - 2||A^T B||_*
|| . ||_* = nuclear norm = sum of singular values sum_i sig(A)_i = ||A||_*

Note: - this only works for matrices. So it's works as a metric for FC and transformers (or at least previous work
only used it for transformer [1] which have FC and no convolutions.
- note,

ref:
- [1] https://arxiv.org/abs/2108.01661
- [2] https://discuss.pytorch.org/t/is-there-an-orthogonal-procrustes-for-pytorch/131365
- [3] https://ee227c.github.io/code/lecture5.html#nuclear-norm

:param x1:
:param x2:
:return:
"""
from torch.linalg import norm
# x1x2 = torch.bmm(x1, x2)
x1x2 = x1.t() @ x2
d: Tensor = norm(x1, 'fro') + norm(x2, 'fro') - 2 * norm(x1x2, 'nuc')
d: Tensor = d / 2.0 if normalize else d
return d

def orthogonal_procrustes_similairty(x1: Tensor, x2: Tensor, normalize: bool = False) -> Tensor:
    """
    Returns the orthogonal Procrustes similarity, derived from orthogonal_procrustes_distance.

    - normalize=True:  sim = 1 - d, where d is the normalized distance (intended to lie in
      [0, 1]), so sim is intended to lie in [0, 1].
    - normalize=False: sim = 2 - d. This lies in [0, 2] only when both inputs have unit
      Frobenius norm; for arbitrary inputs it is not bounded to that interval.

    See orthogonal_procrustes_distance for details and references.

    Note: the function name keeps its original misspelling for backward compatibility;
    ``orthogonal_procrustes_similarity`` is provided as a correctly spelled alias.

    :param x1: first matrix.
    :param x2: second matrix, with the same number of rows as x1.
    :param normalize: forwarded to orthogonal_procrustes_distance; also selects the offset
        (1 when normalized, 2 otherwise).
    :return: scalar tensor with the similarity.
    """
    d = orthogonal_procrustes_distance(x1, x2, normalize)
    offset = 1.0 if normalize else 2.0
    return offset - d


# Correctly spelled alias; the misspelled name above is kept so existing callers still work.
orthogonal_procrustes_similarity = orthogonal_procrustes_similairty
``````

Example usage:

``````
def op_test(force: bool = True) -> None:
    """
    Smoke test: layer-wise orthogonal Procrustes similarity between two small models.

    Builds two models with ``hardcoded_3_layer_model(5, 1)``, feeds the same random batch
    through both of them layer by layer, and pretty-prints a list of ``(layer_name, similarity)``
    pairs computed with ``l2_sim_torch(..., sim_type='op_torch')``.

    :param force: if True (the default, matching the previously hard-coded value), measure the
        similarity after every layer; otherwise only after layers whose name contains 'ReLU'.
    """
    from pprint import pprint

    import torch

    from uutils.torch_uu import l2_sim_torch
    from uutils.torch_uu.models import hardcoded_3_layer_model

    mdl1 = hardcoded_3_layer_model(5, 1)
    mdl2 = hardcoded_3_layer_model(5, 1)
    batch_size = 4
    X = torch.randn(batch_size, 5)

    # get [..., s_l, ...] sim per layer (for this data set)
    sims_per_layer = []
    out1, out2 = X, X
    for (name1, m1), (_name2, m2) in zip(mdl1.named_children(), mdl2.named_children()):
        # Always propagate through every layer; the filter below only decides where to measure.
        # Previously non-ReLU layers were skipped entirely when force was False.
        out1 = m1(out1)
        out2 = m2(out2)
        if force or 'ReLU' in name1:
            sim = l2_sim_torch(out1, out2, sim_type='op_torch')
            sims_per_layer.append((name1, sim))
    pprint(sims_per_layer)

# -- __main__

if __name__ == '__main__':
    # Entry point: run the orthogonal Procrustes smoke test, then print a completion message
    # (the '\a' escape is the terminal bell character).
    op_test()
    print('Done\a')
``````

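Output (note: several values exceed 2, which a similarity of the form 2 - d with a non-negative distance d should never do):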

``````
[('fc0', tensor(5.7326, grad_fn=<RsubBackward1>)),
('ReLU0', tensor(2.6101, grad_fn=<RsubBackward1>)),
('fc1', tensor(3.8898, grad_fn=<RsubBackward1>)),
('ReLU2', tensor(1.3644, grad_fn=<RsubBackward1>)),
('fc3', tensor(1.5007, grad_fn=<RsubBackward1>))]
Donea

``````