Numerical stability of torch.svd

I am working on a project where my model outputs a vector in R^9, which I project onto the nearest rotation matrix with the function below, and then apply a mean-squared-error loss between the projected prediction and its ground-truth value.

The docs mention a gradient stability issue when the singular values get too close to each other.
https://pytorch.org/docs/stable/generated/torch.linalg.svd.html

The singular values of a rotation matrix are (1, 1, 1), so as the predicted matrix gets closer to lying on SO(3), its singular values approach (1, 1, 1), and that is exactly the degenerate case the docs warn about.
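To make the concern concrete, here is a small demonstration (my construction, not from the docs): backpropagating through torch.linalg.svd at the identity, where the singular values coincide exactly, produces non-finite gradients, because the backward pass divides by differences of squared singular values.

import torch

a = torch.eye(3, requires_grad=True)   # singular values are exactly (1, 1, 1)
u, s, vh = torch.linalg.svd(a)
(u @ vh).sum().backward()              # backward divides by sigma_i^2 - sigma_j^2
print(a.grad)                          # expect non-finite (NaN/Inf) entries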

Any comments? The only alternative I can think of at this point is to change the encoding of the rotation so that I can project it onto the group without the SVD, for example by using a unit quaternion (I can just normalize the 4-vector; see the sketch after the code below).

But there are some reasons I'd like the prediction to come directly as a vector in R^9, so if I can get it working this way, I'd like to.

import torch

def project_onto_so3(rotation_out_group):
    """
    Project a matrix onto SO(3) by finding the closest rotation matrix in Frobenius norm.
    rotation_out_group.shape = (n_batch, 3, 3)
    """
    # note: torch.svd returns V, not V^H, so it must be transposed before use
    u, _, v = torch.svd(rotation_out_group, some=False, compute_uv=True)
    vt = v.transpose(-2, -1)
    # flip the sign paired with the smallest singular value when the
    # determinant is negative, so the result lands in SO(3), not just O(3)
    ones_3 = torch.ones(len(rotation_out_group), 3, device=rotation_out_group.device)  # TODO: precompute?
    ones_3[:, -1] = torch.linalg.det(u @ vt)
    rotation_in_group = u @ torch.diag_embed(ones_3) @ vt
    return rotation_in_group
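For completeness, the quaternion route I mentioned would be something like this (just the standard unit-quaternion-to-rotation-matrix formula):

import torch

def quat_to_rotation(q):
    """
    Normalize a batch of 4-vectors and convert them to rotation matrices.
    q.shape = (n_batch, 4), ordered (w, x, y, z)
    """
    q = q / q.norm(dim=-1, keepdim=True)   # unit quaternion
    w, x, y, z = q.unbind(dim=-1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)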

Hi Geoff!

I think your concern is valid. (For a generic matrix that isn’t orthogonal,
your singular-value-decomposition approach is a fine way to orthogonalize
it, but, for the reason you discussed, if the matrix starts out being close to
orthogonal, you won’t be able to backpropagate through the orthogonalization
in a numerically stable way.)

A standard iterative orthogonalization algorithm that should meet your
needs is illustrated in this script:

import torch

torch.random.manual_seed(2023)

def orthoA(a):
    # direct method: the closest orthogonal matrix to a in Frobenius norm
    u, s, vh = torch.linalg.svd(a)
    o = u @ vh
    return o

def orthoB(a):
    # iterative method that avoids svd (Newton iteration for the
    # orthogonal polar factor)
    nIter = 5     # this is naive -- could use a convergence criterion
    o = a
    for i in range(nIter):
        o = (o + torch.linalg.inv(o.T)) / 2
    return o

m = torch.randn(3, 3)

print('m = ...')
print(m)

oA = orthoA(m)
oB = orthoB(m)

print('oA = ...')
print(oA)

print('torch.allclose (oA, oB, atol = 1.e-7) =', torch.allclose(oA, oB, atol=1.e-7))

Here is its output:

m = ...
tensor([[-0.0794, -1.0846, -1.5421],
        [ 0.9377, -0.9787,  2.0930],
        [ 1.0231,  0.5431,  0.6514]])
oA = ...
tensor([[ 0.2961, -0.6718, -0.6790],
        [ 0.2847, -0.6165,  0.7341],
        [ 0.9118,  0.4106, -0.0088]])
torch.allclose (oA, oB, atol = 1.e-7) = True

PyTorch has no problem backpropagating through such an iterative
computation (although for algorithms that converge more slowly, many
iterations can be necessary and the computation graph can become large).
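
As a quick sanity check along those lines, gradients through the iteration
remain finite even when the input is already exactly a rotation, which is
precisely the case where backpropagating through the svd breaks down:

import torch

m = torch.eye(3, requires_grad=True)   # exactly a rotation, the svd-unfriendly case
o = m
for _ in range(5):
    o = (o + torch.linalg.inv(o.T)) / 2
o.sum().backward()
print(torch.isfinite(m.grad).all())    # tensor(True)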

This algorithm fails if m is singular (the very first iteration already has to invert it), but that is to be expected.

You can look at this iterative algorithm as generating the orthogonal factor
Q of the polar decomposition m = QP (with P symmetric positive semidefinite);
it is the Newton iteration for that factor and converges quadratically.
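
For the batched use case in your post, the same iteration carries over
directly, since torch.linalg.inv and a batched transpose broadcast over the
leading dimension. Something like:

import torch

def project_onto_so3_iterative(m, n_iter=5):
    """
    Batched Newton iteration for the orthogonal polar factor of m.
    m.shape = (n_batch, 3, 3)
    """
    o = m
    for _ in range(n_iter):
        o = (o + torch.linalg.inv(o.transpose(-2, -1))) / 2
    return o

Note that the iteration preserves the sign of the determinant, so an input
with a negative determinant converges to an orthogonal matrix with
determinant -1; such inputs would still need the sign-flip treatment from
your svd version.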

As an aside, you should switch from torch.svd() to torch.linalg.svd()
in any other code you have, as the former is deprecated. Note that the
two differ slightly in syntax: torch.linalg.svd() returns Vh (the
transpose of the V that torch.svd() returns, for real input) and takes
full_matrices= where torch.svd() takes some= (with the opposite meaning).
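
Concretely, for a real square matrix the two relate as follows:

import torch

a = torch.randn(3, 3)
u1, s1, v = torch.svd(a, some=False)                   # deprecated; returns V
u2, s2, vh = torch.linalg.svd(a, full_matrices=True)   # returns V^H (= V^T for real a)

# both reconstruct a, with v playing the role of vh.T:
print(torch.allclose(u1 @ torch.diag(s1) @ v.T, a, atol=1.e-6))   # True
print(torch.allclose(u2 @ torch.diag(s2) @ vh, a, atol=1.e-6))    # True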

Best.

K. Frank