Polar decomposition of matrices in pytorch

Does pytorch have a function that calculates the polar decomposition of a matrix, like scipy.linalg.polar? The function torch.polar does something entirely different (it builds a complex tensor from absolute values and angles). If not, why does pytorch omit this feature?


Hi Steve!

Not that I’m aware of.

Perhaps pytorch should provide polar decomposition. But it’s straightforward
enough to implement it in terms of pytorch’s singular-value decomposition.
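To see why the SVD is enough, write the SVD of a square m as $m = U S V^H$, where $S$ is diagonal with nonnegative entries (in pytorch's notation, `Vh` is $V^H$), and regroup the factors:

```latex
m = U S V^H = \underbrace{(U V^H)}_{u}\;\underbrace{(V S V^H)}_{p},
\qquad u^H u = V U^H U V^H = V V^H = I,
\qquad p^H = p, \quad x^H p\, x = \lVert S^{1/2} V^H x \rVert^2 \ge 0 .
```

So $u$ is unitary and $p$ is Hermitian positive semidefinite, which is exactly the (right) polar decomposition $m = u\,p$.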

Here is such a pytorch implementation in a script that verifies its result against
scipy and checks its autograd differentiability:

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

def polar_decomp (m):   # express polar decomposition in terms of singular-value decomposition
    U, S, Vh = torch.linalg.svd (m, full_matrices = False)   # thin svd, so non-square m also works
    u = U @ Vh                                                 # unitary factor
    p = Vh.mH @ S.diag().to (dtype = m.dtype) @ Vh             # hermitian positive-semidefinite factor, m = u @ p
    return  u, p

m = torch.randn (5, 5, dtype = torch.complex128)   # works for both real and complex m
print ('sample input m = ...')
print (m)

u, p = polar_decomp (m)

import scipy.linalg   # check against scipy
print (scipy.__version__)
uchk, pchk = scipy.linalg.polar (m)

print ('u (pytorch) = uchk (scipy)? :', torch.equal (u, torch.from_numpy (uchk)))
print ('p (pytorch) = pchk (scipy)? :', torch.equal (p, torch.from_numpy (pchk)))

# polar_decomp supports autograd because it is implemented with autograd-supporting pytorch operations
m.requires_grad = True
print ('torch.autograd.gradcheck (polar_decomp, m):', torch.autograd.gradcheck (polar_decomp, m))
```

And here is its output:

```
2.0.1
sample input m = ...
tensor([[ 0.7081+0.9242j, -0.3647+0.2722j,  0.5486-0.4323j,  0.4702+0.7164j,
         -0.2536+0.1605j],
        [-0.5186-0.4888j, -0.5726-0.3178j,  0.2299-0.3169j,  0.3628+0.4701j,
         -0.4271+0.0599j],
        [ 0.2517-0.1391j,  0.3351-0.4699j, -0.1361-0.3724j, -1.0979+0.0631j,
          1.3223-1.0719j],
        [-0.8938-0.0749j, -0.2021+0.4620j,  0.5332+0.7327j, -0.2232-0.0567j,
          0.5612+0.0902j],
        [ 0.3240+1.1715j, -0.3395+0.5913j,  0.1341+0.5250j, -0.4009-0.7055j,
         -0.2356-0.0124j]], dtype=torch.complex128)
1.11.1
u (pytorch) = uchk (scipy)? : True
p (pytorch) = pchk (scipy)? : True
torch.autograd.gradcheck (polar_decomp, m): True
```
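The script above exercises a single square matrix and scipy's default `side='right'`. As a sketch, assuming one also wants scipy's `side` option and batched or non-square inputs, the same SVD idea extends as follows. `polar_decomp_general` is a hypothetical name, not a pytorch API, and instead of comparing against scipy the usage code checks the defining properties directly (reconstruction, orthonormal columns, Hermitian and positive-semidefinite p).

```python
import torch

def polar_decomp_general(m, side="right"):
    # Hypothetical helper, not a PyTorch API. Mirrors the `side` argument of
    # scipy.linalg.polar: side="right" gives m = u @ p, side="left" gives m = p @ u.
    # Accepts a batch of shape (..., k, n), real or complex.
    U, S, Vh = torch.linalg.svd(m, full_matrices=False)   # thin SVD: m = U diag(S) Vh
    u = U @ Vh                                             # orthonormal-column / unitary factor
    S = S.to(m.dtype).unsqueeze(-2)                        # (..., 1, r): scales columns when broadcast
    if side == "right":
        p = (Vh.mH * S) @ Vh      # V diag(S) V^H, shape (..., n, n)
    elif side == "left":
        p = (U * S) @ U.mH        # U diag(S) U^H, shape (..., k, k)
    else:
        raise ValueError("side must be 'right' or 'left'")
    return u, p

# Usage: check the defining properties instead of comparing against scipy.
torch.manual_seed(0)
m = torch.randn(3, 6, 4, dtype=torch.complex128)   # a batch of three tall 6 x 4 matrices
u, p = polar_decomp_general(m)
eye = torch.eye(4, dtype=m.dtype).expand(3, 4, 4)
print(torch.allclose(u @ p, m))                    # m = u @ p
print(torch.allclose(u.mH @ u, eye))               # u has orthonormal columns
print(torch.allclose(p, p.mH))                     # p is Hermitian
print(bool((torch.linalg.eigvalsh(p) >= -1e-12).all()))  # p is positive semidefinite

m2 = torch.randn(5, 5, dtype=torch.float64)
u2, p2 = polar_decomp_general(m2, side="left")
print(torch.allclose(p2 @ u2, m2))                 # m2 = p2 @ u2
```

Broadcasting `S` as a row vector scales the columns of `Vh.mH` (or `U`), which avoids building `diag(S)` and works unchanged for batches.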

Best.

K. Frank


Thanks Frank! I checked the source code of scipy.linalg.polar and found that it also uses SVD under the hood. However, JAX's polar decomposition (jax.scipy.linalg.polar) has two implementations: one based on SVD and one called QDWH (QR-based dynamically weighted Halley iteration).
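For a sense of what an SVD-free method looks like, here is a minimal sketch of the classical Newton iteration for the unitary polar factor, $x_{k+1} = \tfrac12 (x_k + x_k^{-H})$, which converges to $u$ for a square, nonsingular m. To be clear, this is not QDWH and not what JAX or scipy run; `polar_newton`, `max_iter` and `tol` are hypothetical names and values chosen for illustration.

```python
import torch

def polar_newton(m, max_iter=100, tol=1e-12):
    # Hypothetical helper: unitary polar factor by the (unscaled) Newton iteration
    #   x_{k+1} = (x_k + x_k^{-H}) / 2,
    # which converges to u for square, nonsingular m. No SVD is used.
    x = m
    for _ in range(max_iter):
        x_next = 0.5 * (x + torch.linalg.inv(x).mH)
        converged = torch.linalg.matrix_norm(x_next - x) <= tol * torch.linalg.matrix_norm(x_next)
        x = x_next
        if converged:
            break
    u = x
    p = u.mH @ m              # equals V diag(S) V^H in exact arithmetic
    p = 0.5 * (p + p.mH)      # symmetrize to remove rounding asymmetry
    return u, p

torch.manual_seed(2023)
m = torch.randn(5, 5, dtype=torch.complex128)
u, p = polar_newton(m)

U, S, Vh = torch.linalg.svd(m)
print(torch.allclose(u, U @ Vh))   # same unitary factor as the SVD-based route
print(torch.allclose(u @ p, m))    # m = u @ p
```

Unscaled Newton can need many iterations when m is ill-conditioned; as far as I know, scaled Newton variants and QDWH are designed to converge in a small, nearly fixed number of iterations built from matmuls, inversions or QR factorizations, which map well to accelerators.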