Is it possible to use `parametrizations.orthogonal` inside a parametrization?

I want to learn a parameter that is constrained to be a symmetric positive definite (SPD) matrix, and I want to use PyTorch parametrizations for that.

I am currently using the spectral decomposition of the SPD matrices, with the following implementation:

import torch
import torch.nn as nn

class SPD(nn.Module):
    """Constrains the matrix to be symmetric positive definite."""

    def __init__(self):
        super().__init__()

    def forward(self, X, Y):
        # Turn vector X into orthogonal matrix (the eigenvectors)
        eigvec = vec_2_orthogonal_matrix(X)
        # Turn vector Y into positive vector (the eigenvalues)
        eigvals = softmax(Y)
        # Generate SPD matrix as eigvec @ diag(eigvals) @ eigvec.T
        SPD = torch.einsum('ij,j,jk->ik', eigvec, eigvals, eigvec.t())
        return SPD

    def right_inverse(self, SPD):
        # Take spectral decomposition of matrix
        eigvals, Q = torch.linalg.eigh(SPD)
        # Convert orthogonal matrix to vector
        X = triu_2_vec(
          orthogonal_2_skew(Q)
        )
        # Map the positive eigenvalues back to an unconstrained vector
        Y = inv_softmax(eigvals)
        return X, Y

This code turns an unconstrained vector `X` into an orthogonal matrix `eigvec` of size (n, n) using the matrix exponential, and an unconstrained vector `Y` into a vector `eigvals` of size (n) whose elements are all positive. `eigvec` and `eigvals` are then combined into the SPD matrix `eigvec @ diag(eigvals) @ eigvec.T`. The helpers `vec_2_orthogonal_matrix`, `softmax`, `inv_softmax`, `triu_2_vec` and `orthogonal_2_skew` are my own functions and are not shown here.
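For readers who want to see the matrix-exponential step concretely, here is a minimal sketch of one way such a helper could look. It is not the helper used above: the names `vec_2_skew` and `skew_exp_orthogonal`, and the extra `n` argument, are hypothetical and only for illustration. The idea is that the exponential of a skew-symmetric matrix is orthogonal.

```python
import torch


def vec_2_skew(x: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: put x in the strict upper triangle, then antisymmetrize."""
    rows, cols = torch.triu_indices(n, n, offset=1)
    upper = torch.zeros(n, n, dtype=x.dtype)
    upper[rows, cols] = x
    return upper - upper.T  # skew-symmetric: A.T == -A


def skew_exp_orthogonal(x: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical helper: exp of a skew-symmetric matrix is orthogonal."""
    return torch.linalg.matrix_exp(vec_2_skew(x, n))


n = 4
# An (n, n) skew-symmetric matrix has n * (n - 1) / 2 free entries
x = torch.randn(n * (n - 1) // 2, dtype=torch.float64)
Q = skew_exp_orthogonal(x, n)
```

`Q @ Q.T` should equal the identity up to floating-point error, which is the property the SPD construction relies on.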

When I define my model class, I call `parametrize.register_parametrization(self, "covariance", SPD())`, and the parameter `covariance` is then constrained to be SPD.

This code currently works, but I would like to use PyTorch’s built-in `torch.nn.utils.parametrizations.orthogonal()` to make `eigvec` orthogonal instead of my own implementation. My understanding is that `parametrizations.orthogonal()` uses trivializations, which help with convergence.

However, it is unclear to me how to use `parametrizations.orthogonal()` inside another parametrization. `parametrizations.orthogonal()` constrains an `nn.Parameter` of a module to be orthogonal, and my implementation above does not contain any `nn.Parameter`.
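For context, this is how `orthogonal()` is normally applied: to a named tensor of an existing module. A minimal sketch, where the `nn.Linear(4, 4)` layer is just an arbitrary example:

```python
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

# Constrain the "weight" parameter of a layer to be orthogonal
layer = orthogonal(nn.Linear(4, 4))

# Accessing layer.weight recomputes the orthogonal matrix from the
# unconstrained tensor stored in layer.parametrizations.weight.original
W = layer.weight
is_orthogonal = torch.allclose(W @ W.T, torch.eye(4), atol=1e-4)
```

The function needs a module and the name of one of its tensors, which is why it does not fit directly into a `forward(self, X, Y)` that only receives plain tensors.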

How can I use `parametrizations.orthogonal()` in the example above?

It turns out this is possible. We just need to define the orthogonal matrix as an `nn.Parameter` inside the parametrization class and apply `orthogonal()` to it. The following code works for this (`softmax` and `inv_softmax` are the same helpers as before):

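from torch.nn.utils.parametrizations import orthogonal

class SPD(nn.Module):
    """Constrains the matrix to be symmetric positive definite."""

    def __init__(self, n_dim):
        super().__init__()
        # Eigenvector matrix, stored as a parameter of this module
        self.ortho = nn.Parameter(torch.eye(n_dim))
        # Constrain self.ortho to be orthogonal
        orthogonal(self, "ortho")
        # Assigning to a parametrized attribute goes through the
        # parametrization's right_inverse; this initializes ortho to identity
        self.ortho = torch.eye(n_dim)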

    def forward(self, Y):
        """ Make SPD matrix."""
        # Turn vector Y into positive vector
        eigvals = softmax(Y)
        # Generate SPD matrix
        SPD = torch.einsum(
          'ij,j,jk->ik', self.ortho, eigvals, self.ortho.t()
        )
        return SPD

    def right_inverse(self, SPD):
        # Take spectral decomposition of matrix
        eigvals, ortho = torch.linalg.eigh(SPD)
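        # Update orthogonal matrix (the assignment goes through the
        # right_inverse of the orthogonal parametrization)
        self.ortho = ortho
        # Map the positive eigenvalues back to an unconstrained vector
        Y = inv_softmax(eigvals)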
        return Y

Inside our model, we can then call `parametrize.register_parametrization(self, "spd", SPD(n_dim))`, and the parameter `self.spd` will be constrained to be SPD. Note that `n_dim` has to be passed to `SPD` when calling `parametrize.register_parametrization`, because the parametrization now owns the (n_dim, n_dim) orthogonal matrix.
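To make this concrete, here is a self-contained sketch of the whole setup. Since `softmax` and `inv_softmax` are not shown above, the sketch assumes `torch.exp` and `torch.log` as stand-ins for the positivity map and its inverse. The `Model` class, the initial matrix, the target and the learning rate are likewise hypothetical and only there to exercise the parametrization.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import orthogonal


class SPD(nn.Module):
    """Constrains the matrix to be symmetric positive definite."""

    def __init__(self, n_dim):
        super().__init__()
        self.ortho = nn.Parameter(torch.eye(n_dim))
        orthogonal(self, "ortho")
        self.ortho = torch.eye(n_dim)

    def forward(self, Y):
        eigvals = torch.exp(Y)  # assumption: exp as the positivity map
        return torch.einsum("ij,j,jk->ik", self.ortho, eigvals, self.ortho.t())

    def right_inverse(self, spd):
        eigvals, ortho = torch.linalg.eigh(spd)
        self.ortho = ortho
        return torch.log(eigvals)  # assumption: log as the inverse map


class Model(nn.Module):  # hypothetical model holding one SPD parameter
    def __init__(self, init):
        super().__init__()
        self.spd = nn.Parameter(init)
        parametrize.register_parametrization(self, "spd", SPD(init.shape[0]))


torch.manual_seed(0)
A = torch.randn(3, 3)
init = A @ A.T + 3.0 * torch.eye(3)  # an arbitrary SPD starting point
model = Model(init)

# right_inverse ran at registration, so the initial value is recovered
recon_ok = torch.allclose(model.spd, init, atol=1e-3)

# One gradient step on the unconstrained tensors; model.spd stays SPD
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
loss = (model.spd - torch.eye(3)).pow(2).sum()
loss.backward()
opt.step()
S = model.spd.detach()
```

After the step, `S` should still be symmetric with strictly positive eigenvalues, because the optimizer only updates the unconstrained tensors behind `Y` and `ortho`, never the SPD matrix itself.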