PyTorch: Square root of a positive semi-definite matrix

Using PyTorch, I want to compute the square root of a positive semi-definite matrix. I have googled around for a PyTorch implementation but can't seem to find the right one.

This is what I have found:

  • https://github.com/steveli/pytorch-sqrtm (this implementation appears to work only for positive definite matrices; I am after one that works for positive semi-definite matrices).
  • https://github.com/pytorch/pytorch/issues/25481 (this implementation also appears to work only for positive definite matrices. Also, the issue is still open, so I guess they haven't settled on a final version yet).
  • https://github.com/msubhransu/matrix-sqrt (according to the second GitHub link above, this implementation is not fully PyTorch: it uses PyTorch for the backward pass and SciPy for the forward pass. It doesn't say anything about positive definite or positive semi-definite matrices. Also, when I had a look at the PyTorch code, I couldn't understand it because it doesn't seem to expose a class or function to call).

Does anyone know where I could find a PyTorch implementation that computes the square root of a positive semi-definite matrix? I would greatly appreciate it. Many thanks in advance.

Hello Leo!

Perform an eigendecomposition of your matrix and then take the
square root of its eigenvalues. (If any eigenvalues of your
semi-definite matrix show up as numerically negative, replace them
with zero.)
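
For concreteness, here is a minimal sketch of that recipe (the helper name is my own; torch.linalg.eigh is the current name of the symmetric eigensolver that torch.symeig became):

import torch

# Minimal sketch of the eigendecomposition recipe above (the helper
# name is mine). torch.linalg.eigh is the symmetric/Hermitian eigensolver.
def psd_sqrt(a: torch.Tensor) -> torch.Tensor:
    vals, vecs = torch.linalg.eigh(a)
    vals = vals.clamp(min=0.0)  # zero out numerically negative eigenvalues
    return vecs @ torch.diag(vals.sqrt()) @ vecs.transpose(-2, -1)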

For more detail, see this post:

Best.

K. Frank


Hi @KFrank, many thanks for your solution. The solution makes sense and could definitely work for what I need to do. I am implementing a new type of classifier, and one of the functions it uses is the square root of a positive semi-definite matrix.

One quick question, though: if this can be done as elegantly as you propose, why are people working on a PyTorch function for this that involves more complicated math/algorithms, e.g. in the GitHub link below, especially given the more recent comments from the last couple of days?

Hello Leo!

First comment:

The code Yaroslav posted at the beginning of the GitHub issue to which
you linked is basically what I suggested. He (properly) treats the null
space of the semi-definite matrix more carefully, and he (properly) uses
torch.symeig() rather than torch.eig().
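
As an illustration of what "treating the null space more carefully" can look like, here is a hedged sketch (the function name and the tolerance are my own assumptions, not taken from the linked issue):

import torch

# Illustrative sketch (mine, not Yaroslav's code): eigenvalues below a
# relative cutoff are treated as exact zeros, so noise in the null space
# does not contaminate the square root. The tolerance is an assumption.
def psd_sqrt_truncated(a: torch.Tensor, rel_tol: float = 1e-7) -> torch.Tensor:
    vals, vecs = torch.linalg.eigh(a)
    cutoff = vals.max().clamp(min=0.0) * rel_tol
    vals = torch.where(vals > cutoff, vals, torch.zeros_like(vals))
    return vecs @ torch.diag(vals.sqrt()) @ vecs.transpose(-2, -1)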

Second comment:

More importantly: I'm not an expert, but I have no reason to believe that
eigendecomposition is the best algorithm for the square root of a matrix.
The GitHub issue discusses other approaches that could be faster and/or
more numerically stable or accurate.

The eigendecomposition contains, in a sense, more information than
the square root, so it could well be more expensive to calculate. By way
of analogy, you can use an eigendecomposition to calculate the inverse of
a matrix (take the reciprocals of the eigenvalues), but that is not the
preferred matrix-inversion algorithm.
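
As a quick numerical check of this analogy (my own example, not from the thread): the eigendecomposition-based inverse matches torch.linalg.inv, but the dedicated LU-based routine is what one would normally reach for.

import torch

# Inverse via eigendecomposition vs. the dedicated routine.
a = torch.randn(4, 4)
a = a @ a.t() + 4.0 * torch.eye(4)  # symmetric positive definite example
vals, vecs = torch.linalg.eigh(a)
inv_via_eig = vecs @ torch.diag(1.0 / vals) @ vecs.t()
print(torch.allclose(inv_via_eig, torch.linalg.inv(a), atol=1e-5))  # True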

Best.

K. Frank

For anyone interested in the same topic in the future, here are two code snippets taken from the PIQ library:

  1. Computation using the iterative Newton-Schulz algorithm
from typing import Tuple

import torch


def _approximation_error(matrix: torch.Tensor, s_matrix: torch.Tensor) -> torch.Tensor:
    # Relative residual ||A - S @ S|| / ||A||, used as the convergence check.
    norm_of_matrix = torch.norm(matrix)
    error = matrix - torch.mm(s_matrix, s_matrix)
    error = torch.norm(error) / norm_of_matrix
    return error


def _sqrtm_newton_schulz(matrix: torch.Tensor, num_iters: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Square root of matrix using Newton-Schulz Iterative method
    Source: https://github.com/msubhransu/matrix-sqrt/blob/master/matrix_sqrt.py
    Args:
        matrix: matrix or batch of matrices
        num_iters: Number of iteration of the method
    Returns:
        Square root of matrix
        Error
    """
    expected_num_dims = 2
    if matrix.dim() != expected_num_dims:
        raise ValueError(f'Input dimension equals {matrix.dim()}, expected {expected_num_dims}')

    if num_iters <= 0:
        raise ValueError(f'Number of iteration equals {num_iters}, expected greater than 0')

    dim = matrix.size(0)
    # Normalize so that the iteration starts inside its convergence region.
    norm_of_matrix = matrix.norm(p='fro')
    Y = matrix.div(norm_of_matrix)
    I = torch.eye(dim, dim, requires_grad=False).to(matrix)
    Z = torch.eye(dim, dim, requires_grad=False).to(matrix)

    s_matrix = torch.empty_like(matrix)
    error = torch.empty(1).to(matrix)

    # Coupled Newton-Schulz iteration: Y converges to sqrt(A / ||A||)
    # and Z to its inverse.
    for _ in range(num_iters):
        T = 0.5 * (3.0 * I - Z.mm(Y))
        Y = Y.mm(T)
        Z = T.mm(Z)

        # Undo the initial normalization and stop once the residual is small.
        s_matrix = Y * torch.sqrt(norm_of_matrix)
        error = _approximation_error(matrix, s_matrix)
        if torch.isclose(error, torch.tensor([0.]).to(error), atol=1e-5):
            break

    return s_matrix, error
  2. Using eigendecomposition, as discussed in this topic
def _matrix_pow(matrix: torch.Tensor, p: float) -> torch.Tensor:
    r"""
    Power of a matrix using Eigen Decomposition.
    Args:
        matrix: matrix
        p: power
    Returns:
        Power of a matrix
    """
    # torch.eig (since removed; torch.linalg.eig is the newer replacement)
    # returns eigenvalues as an (n, 2) tensor of (real, imag) pairs.
    vals, vecs = torch.eig(matrix, eigenvectors=True)
    vals = torch.view_as_complex(vals.contiguous())
    vals_pow = vals.pow(p)
    # Keep only the real part of each powered eigenvalue.
    vals_pow = torch.view_as_real(vals_pow)[:, 0]
    matrix_pow = torch.matmul(vecs, torch.matmul(torch.diag(vals_pow), torch.inverse(vecs)))
    return matrix_pow
    return matrix_pow
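
To round this out, here is a hypothetical usage sketch (my own, not part of PIQ); note that _matrix_pow relies on torch.eig, so it needs a PyTorch version that still ships that function:

import torch

# Hypothetical usage (not from the PIQ library): build a random PSD matrix
# and inspect the residuals of both routines.
a = torch.randn(5, 5)
a = a @ a.t()  # symmetric positive semi-definite by construction

s, err = _sqrtm_newton_schulz(a)
print(err.item())                            # relative residual of the iteration
print(torch.norm(s @ s - a).item())          # absolute residual, near zero

s_eig = _matrix_pow(a, 0.5)                  # requires a PyTorch with torch.eig
print(torch.norm(s_eig @ s_eig - a).item())  # near zero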