# Polar decomposition of matrices in pytorch

Does pytorch have a function that computes the polar decomposition of a matrix, like `scipy.linalg.polar`? (`torch.polar` is something entirely different: it builds a complex tensor from magnitudes and angles.) If not, why does pytorch omit this feature?


Hi Steve!

Not that I’m aware of.

Perhaps pytorch should provide polar decomposition. But it’s straightforward
enough to implement it in terms of pytorch’s singular-value decomposition.
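For reference, here is the standard identity behind the SVD approach, written for the right polar decomposition $m = u\,p$ (the `scipy.linalg.polar` default) and for a square or tall $m$, so that $V$ is square and unitary. Inserting $V^H V = I$ into the SVD splits $m$ into the two factors:

$$
m = U \Sigma V^H = \underbrace{(U V^H)}_{u}\,\underbrace{(V \Sigma V^H)}_{p},
\qquad
u^H u = V U^H U V^H = I,
\qquad
p = V \Sigma V^H = (m^H m)^{1/2} \succeq 0 .
$$

Since the singular values in $\Sigma$ are non-negative, $p$ is Hermitian positive semidefinite. In pytorch's notation `Vh` $= V^H$, so $u$ is `U @ Vh` and $p$ is `Vh` conjugate-transposed times `diag(S)` times `Vh`.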

Here is such a pytorch implementation in a script that verifies its result against
scipy (the closing comment explains why it also supports autograd):

```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

def polar_decomp (m):   # express polar decomposition in terms of singular-value decomposition
    # full_matrices = False gives the thin SVD, so U @ Vh also works for non-square m
    U, S, Vh = torch.linalg.svd (m, full_matrices = False)
    u = U @ Vh                                        # unitary factor
    p = Vh.mH @ S.diag().to (dtype = m.dtype) @ Vh    # Hermitian positive-semidefinite factor
    return  u, p

m = torch.randn (5, 5, dtype = torch.complex128)   # works for both real and complex m
print ('sample input m = ...')
print (m)

u, p = polar_decomp (m)

import scipy.linalg   # check against scipy (import the linalg submodule explicitly)
print (scipy.__version__)
uchk, pchk = scipy.linalg.polar (m.numpy())

# compare up to floating-point round-off rather than bit-for-bit
print ('u (pytorch) = uchk (scipy)? :', torch.allclose (u, torch.from_numpy (uchk)))
print ('p (pytorch) = pchk (scipy)? :', torch.allclose (p, torch.from_numpy (pchk)))

# polar_decomp supports autograd because it is implemented with autograd-supporting pytorch operations
```
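The script above only states in a comment that `polar_decomp` supports autograd. A minimal sketch of how one could actually test that claim with `torch.autograd.gradcheck`, which compares autograd gradients against finite differences. It assumes a real `float64` input (gradcheck needs double precision, and a real input keeps the check simple), repeats `polar_decomp` so the block runs on its own, and is separate from the original script, so its results are not part of the output shown below.

```python
import torch

def polar_decomp(m):
    # same SVD-based construction as the script: m = u @ p
    U, S, Vh = torch.linalg.svd(m, full_matrices=False)
    u = U @ Vh
    p = Vh.mH @ torch.diag_embed(S).to(dtype=m.dtype) @ Vh
    return u, p

torch.manual_seed(2023)
# double precision so gradcheck's finite differences are accurate enough
m = torch.randn(5, 5, dtype=torch.float64, requires_grad=True)

u, p = polar_decomp(m)
eye = torch.eye(5, dtype=m.dtype)
print('u unitary:  ', torch.allclose(u.mH @ u, eye))
print('p Hermitian:', torch.allclose(p, p.mH))
print('u @ p == m: ', torch.allclose(u @ p, m))

# compare autograd Jacobians of both outputs (u, p) against finite differences
print('gradcheck:  ', torch.autograd.gradcheck(polar_decomp, (m,)))
```

The check relies on the singular values being distinct, which holds almost surely for a random matrix; the SVD gradient becomes unstable near repeated singular values.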

And here is its output:

```
2.0.1
sample input m = ...
tensor([[ 0.7081+0.9242j, -0.3647+0.2722j,  0.5486-0.4323j,  0.4702+0.7164j,
         -0.2536+0.1605j],
        [-0.5186-0.4888j, -0.5726-0.3178j,  0.2299-0.3169j,  0.3628+0.4701j,
         -0.4271+0.0599j],
        [ 0.2517-0.1391j,  0.3351-0.4699j, -0.1361-0.3724j, -1.0979+0.0631j,
          1.3223-1.0719j],
        [-0.8938-0.0749j, -0.2021+0.4620j,  0.5332+0.7327j, -0.2232-0.0567j,
          0.5612+0.0902j],
        [ 0.3240+1.1715j, -0.3395+0.5913j,  0.1341+0.5250j, -0.4009-0.7055j,
         -0.2356-0.0124j]], dtype=torch.complex128)
1.11.1
u (pytorch) = uchk (scipy)? : True
p (pytorch) = pchk (scipy)? : True
```

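Thanks Frank! I checked the source code of `scipy.linalg.polar` and found that it also uses SVD under the hood. JAX's polar decomposition (`jax.scipy.linalg.polar`), however, has two implementations, selected by its `method` argument: one based on SVD and another called QDWH (QR-based dynamically weighted Halley iteration).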
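QDWH itself is fairly involved, and the sketch below is *not* JAX's implementation. As a simpler illustration that the unitary factor can be computed without an SVD, here is the classic Newton iteration $X_{k+1} = \tfrac12 (X_k + X_k^{-H})$, which converges to $u$ for a square, nonsingular $m$; $p$ then follows as $u^H m$. `polar_newton`, `max_iter` and `tol` are hypothetical names and arbitrary choices, not a pytorch API, and the result is compared against the SVD-based `polar_decomp` from the thread.

```python
import torch

def polar_decomp(m):
    # SVD-based reference, as in the script earlier in the thread
    U, S, Vh = torch.linalg.svd(m, full_matrices=False)
    return U @ Vh, Vh.mH @ torch.diag_embed(S).to(dtype=m.dtype) @ Vh

def polar_newton(m, max_iter=100, tol=1e-12):
    # hypothetical helper: Newton iteration X <- (X + X^{-H}) / 2,
    # valid only for square, nonsingular m (no SVD involved)
    x = m
    for _ in range(max_iter):
        x_next = 0.5 * (x + torch.linalg.inv(x).mH)
        # stop once the relative Frobenius-norm change is tiny
        converged = torch.linalg.matrix_norm(x_next - x) <= tol * torch.linalg.matrix_norm(x_next)
        x = x_next
        if converged:
            break
    u = x
    p = u.mH @ m            # since m = u @ p and u is unitary
    p = 0.5 * (p + p.mH)    # symmetrize away rounding error
    return u, p

torch.manual_seed(2023)
m = torch.randn(5, 5, dtype=torch.complex128)

u, p = polar_newton(m)
u_svd, p_svd = polar_decomp(m)
print('matches SVD version:', torch.allclose(u, u_svd), torch.allclose(p, p_svd))
```

Plain Newton can need many iterations for badly conditioned matrices; practical iterative methods (QDWH among them) add scaling to speed up convergence.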