Is there an orthogonal procrustes for PyTorch?

Inspired by this paper, as a potentially better representation-similarity metric: [2108.01661] Grounding Representation Similarity with Statistical Testing


Note: SciPy does have it: scipy.linalg.orthogonal_procrustes — SciPy v1.7.1 Manual
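
For reference, the SciPy routine solves the alignment problem itself: it returns the orthogonal matrix R that minimizes ||A @ R - B||_F together with the nuclear norm of A.T @ B, so it can be turned into the same distance on NumPy arrays. A rough sketch (the array shapes are just for illustration):

import numpy as np
from scipy.linalg import orthogonal_procrustes

A = np.random.randn(32, 64)
B = np.random.randn(32, 64)

# R is the orthogonal matrix minimizing ||A @ R - B||_F; scale is the nuclear norm ||A.T @ B||_*
R, scale = orthogonal_procrustes(A, B)

# same quantity as the PyTorch distance below: ||A||_F^2 + ||B||_F^2 - 2 * ||A.T @ B||_*
d = np.linalg.norm(A, 'fro') ** 2 + np.linalg.norm(B, 'fro') ** 2 - 2 * scale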

For a pure PyTorch version, this is probably enough:

ref: ultimate-utils-proj-src/uutils/torch_uu/__init__.py at master · brando90/ultimate-utils · GitHub

from torch import Tensor


def orthogonal_procrustes_distance(x1: Tensor, x2: Tensor, normalize: bool = False) -> Tensor:
    """
    Computes the orthogonal Procrustes distance.
    If normalize=True the result is divided by 2, which puts it in the interval [0, 1]
    when the inputs have unit Frobenius norm.

    Expected input:
        - two matrices, e.g.
            - two weight matrices of size [num_weights1, num_weights2]
            - or two matrices of activations [batch_size, dim_of_layer] (used by paper [1])

    d_proc(A, B) = ||A||^2_F + ||B||^2_F - 2||A^T B||_*
    where || . ||_* is the nuclear norm, i.e. the sum of singular values: ||A||_* = sum_i sig(A)_i.

    Note: this only works for matrices, so it works as a metric for FC layers and transformers
    (or at least previous work [1] only used it for transformers, which have FC layers and no
    convolutions).

    ref:
    - [1] https://arxiv.org/abs/2108.01661
    - [2] https://discuss.pytorch.org/t/is-there-an-orthogonal-procrustes-for-pytorch/131365
    - [3] https://ee227c.github.io/code/lecture5.html#nuclear-norm

    :param x1: first matrix, e.g. activations of shape [batch_size, dim_of_layer]
    :param x2: second matrix of the same shape as x1
    :param normalize: if True, divide the distance by 2
    :return: the (optionally halved) orthogonal Procrustes distance
    """
    from torch.linalg import norm
    # nuclear norm of the cross-covariance matrix x1^T x2
    x1x2 = x1.t() @ x2
    d: Tensor = norm(x1, 'fro') ** 2 + norm(x2, 'fro') ** 2 - 2 * norm(x1x2, 'nuc')
    d: Tensor = d / 2.0 if normalize else d
    return d

def orthogonal_procrustes_similarity(x1: Tensor, x2: Tensor, normalize: bool = False) -> Tensor:
    """
    Returns the orthogonal Procrustes similarity. If normalize=True the output is in the
    interval [0, 1]; otherwise it is in the interval [0, 2] (assuming inputs with unit
    Frobenius norm). See orthogonal_procrustes_distance for details and references.

    :param x1: first matrix, e.g. activations of shape [batch_size, dim_of_layer]
    :param x2: second matrix of the same shape as x1
    :param normalize: if True, return 1 - d/2 instead of 2 - d
    :return: the orthogonal Procrustes similarity
    """
    d = orthogonal_procrustes_distance(x1, x2, normalize)
    sim: Tensor = 1.0 - d if normalize else 2.0 - d
    return sim
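
Quick sanity check for the two functions above (a sketch; rescaling to unit Frobenius norm is my assumption for getting the normalized values into [0, 1]):

import torch

A = torch.randn(64, 32)
B = torch.randn(64, 32)

# identical inputs -> distance ~0
print(orthogonal_procrustes_distance(A, A))  # ~0

# rescale to unit Frobenius norm so the normalized distance/similarity land in [0, 1]
A_n = A / torch.linalg.norm(A, 'fro')
B_n = B / torch.linalg.norm(B, 'fro')
print(orthogonal_procrustes_distance(A_n, B_n, normalize=True))    # in [0, 1]
print(orthogonal_procrustes_similarity(A_n, B_n, normalize=True))  # in [0, 1]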

For example, a quick test comparing two randomly initialized networks layer by layer:


def op_test():
    import copy
    from pprint import pprint

    import torch
    from uutils.torch_uu import l2_sim_torch
    from uutils.torch_uu.models import hardcoded_3_layer_model

    force = True
    # force = False
    mdl1 = hardcoded_3_layer_model(5, 1)
    mdl2 = hardcoded_3_layer_model(5, 1)
    batch_size = 4
    X = torch.randn(batch_size, 5)
    # get [..., s_l, ...] similarity per layer (for this data set)
    modules = zip(mdl1.named_children(), mdl2.named_children())
    sims_per_layer = []
    out1 = X
    out2 = X
    for (name1, m1), (name2, m2) in modules:
        # if name1 in layer_names:
        if 'ReLU' in name1 or force:  # only compute on activations (unless forced on every layer)
            out1 = m1(out1)
            # copy m1's module so both branches have the same callable, then load m2's weights into it
            m2_callable = copy.deepcopy(m1)
            m2_callable.load_state_dict(m2.state_dict())
            out2 = m2_callable(out2)
            sim = l2_sim_torch(out1, out2, sim_type='op_torch')
            sims_per_layer.append((name1, sim))
    pprint(sims_per_layer)

# -- __main__

if __name__ == '__main__':
    # test_ned()
    # test_tensorify()
    # test_compressed_r2_score()
    # test_topk_accuracy_and_accuracy()
    # test_simple_determinism()
    op_test()
    print('Done\a')

Output:


[('fc0', tensor(5.7326, grad_fn=<RsubBackward1>)),
 ('ReLU0', tensor(2.6101, grad_fn=<RsubBackward1>)),
 ('fc1', tensor(3.8898, grad_fn=<RsubBackward1>)),
 ('ReLU2', tensor(1.3644, grad_fn=<RsubBackward1>)),
 ('fc3', tensor(1.5007, grad_fn=<RsubBackward1>))]
Done

ref: ultimate-utils-proj-src/uutils/torch_uu/__init__.py#L1321 at bfadaf93594432ee05d669378025e36c5296af02 · brando90/ultimate-utils · GitHub
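
For anyone who doesn't want the uutils dependencies (hardcoded_3_layer_model, l2_sim_torch), here is a minimal self-contained sketch of the same per-layer comparison that calls orthogonal_procrustes_distance directly; the small MLP and its sizes are made up for illustration:

import torch
import torch.nn as nn
from pprint import pprint

def op_test_self_contained():
    torch.manual_seed(0)
    def make_mlp():
        return nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
    mdl1, mdl2 = make_mlp(), make_mlp()  # two independently initialized nets
    X = torch.randn(4, 5)
    out1, out2 = X, X
    dists_per_layer = []
    for (name1, m1), (name2, m2) in zip(mdl1.named_children(), mdl2.named_children()):
        out1, out2 = m1(out1), m2(out2)
        # activations are [batch_size, dim_of_layer] matrices, as expected by the docstring above
        d = orthogonal_procrustes_distance(out1, out2)
        dists_per_layer.append((name1, d))
    pprint(dists_per_layer)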