I am working on a project where my model outputs a vector in R9, which I reshape to a 3x3 matrix and project onto the nearest rotation matrix with the function below, and then apply a mean-squared-error loss between the projected prediction and its ground-truth value.

The docs mention a gradient stability issue when the singular values get too close to each other.

https://pytorch.org/docs/stable/generated/torch.linalg.svd.html

The singular values of a rotation matrix are (1, 1, 1), so as the predicted matrix gets closer to lying on SO(3), its singular values approach (1, 1, 1), which is exactly the repeated-singular-value regime the docs warn about.
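To make the warning concrete, here is a minimal sketch (my own, not from the docs) that backpropagates through an SVD-based reconstruction of the identity, whose singular values are exactly (1, 1, 1):

```python
import torch

# Minimal illustration of the repeated-singular-value issue: the identity
# is already a rotation, so its singular values are exactly (1, 1, 1).
# The backward formula for U and Vh contains 1/(s_i^2 - s_j^2) factors,
# and the docs state these gradients are only finite when the singular
# values are distinct, so non-finite gradients are to be expected here.
eye = torch.eye(3, requires_grad=True)
u, s, vh = torch.linalg.svd(eye)
loss = (u @ vh).sum()
loss.backward()
print(eye.grad)  # per the docs, may contain NaN/inf entries
```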

Any comments? The only alternative I can think of at this point is to change the encoding of the rotation so that I can project it onto the group without the SVD, for example by using a unit quaternion (I can just normalize the 4-vector).
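For reference, the quaternion route could be sketched like this (my own sketch; `quat_to_rotmat` and the (w, x, y, z) ordering are assumptions, not from your code):

```python
import torch

def quat_to_rotmat(q):
    """Map raw 4-vectors onto SO(3): normalize to unit quaternions,
    then apply the standard quaternion-to-matrix formula.
    q.shape = (n_batch, 4), ordered (w, x, y, z)."""
    q = q / q.norm(dim=-1, keepdim=True)  # project onto the unit 3-sphere
    w, x, y, z = q.unbind(dim=-1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)

torch.manual_seed(0)
q = torch.randn(4, 4)       # unnormalized network outputs
R = quat_to_rotmat(q)       # (4, 3, 3), each matrix in SO(3)
```

The normalization is smooth everywhere except q = 0, so this avoids the repeated-singular-value issue entirely.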

But there are some reasons I’d like the prediction to come directly as a vector in R9, so if I can get it working this way I’d like to.

```
import torch

def project_onto_so3(rotation_out_group):
    """
    Project a batch of 3x3 matrices onto SO(3), i.e. find the closest
    rotation matrix in Frobenius norm.
    rotation_out_group.shape = (n_batch, 3, 3)
    """
    # torch.svd is deprecated; torch.linalg.svd returns V^T (vh) directly,
    # whereas torch.svd returned V, so the original code was missing a
    # transpose in both products.
    u, _, vh = torch.linalg.svd(rotation_out_group)
    sigma = torch.ones(len(rotation_out_group), 3, device=rotation_out_group.device)
    # Set the last singular value to det(U V^T) (= +/-1) so the result has
    # determinant +1, i.e. lands in SO(3) rather than just O(3).
    sigma[:, -1] = torch.linalg.det(u @ vh)
    rotation_in_group = u @ torch.diag_embed(sigma) @ vh
    return rotation_in_group
```
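As a quick sanity check of the SVD projection (restated here with `torch.linalg.svd` so the snippet runs standalone), a random batch should map to matrices with orthonormal columns and determinant +1:

```python
import torch

# Standalone restatement of the projection from the post, using
# torch.linalg.svd (which returns V^T directly as vh).
def project_onto_so3(m):
    u, _, vh = torch.linalg.svd(m)
    sigma = torch.ones(m.shape[0], 3, device=m.device)
    sigma[:, -1] = torch.linalg.det(u @ vh)
    return u @ torch.diag_embed(sigma) @ vh

torch.manual_seed(0)
m = torch.randn(8, 3, 3)
r = project_onto_so3(m)
print(torch.allclose(r.transpose(-2, -1) @ r, torch.eye(3).expand(8, 3, 3), atol=1e-5))  # True
print(torch.linalg.det(r))  # all close to 1.0
```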