Does pytorch have a function that calculates the polar decomposition of a matrix, like `scipy.linalg.polar`? The function `torch.polar` is an entirely different one. If not, why does pytorch choose to omit this feature?
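(For context, `torch.polar` constructs complex numbers from polar coordinates, magnitude and angle, elementwise; it has nothing to do with matrix polar decomposition:)

```python
import math
import torch

# torch.polar (abs, angle) builds abs * exp (i * angle) elementwise --
# it assembles complex numbers from polar coordinates; it does not
# decompose a matrix.
r = torch.tensor ([1.0, 2.0])
theta = torch.tensor ([0.0, math.pi / 2])
z = torch.polar (r, theta)
print (z)
# equivalent to r * cos (theta) + i * r * sin (theta)
```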


Hi Steve!

Not that I’m aware of.

Perhaps pytorch should provide polar decomposition, but it's straightforward enough to implement it in terms of pytorch's singular-value decomposition: writing `m = U S Vh`, we have `m = (U @ Vh) @ (Vh.H @ S.diag() @ Vh)`, where `U @ Vh` is unitary and `Vh.H @ S.diag() @ Vh` is positive semidefinite.

Here is such a pytorch implementation in a script that verifies its result against scipy and checks its autograd differentiability:

```
import torch
print (torch.__version__)

_ = torch.manual_seed (2023)

def polar_decomp (m):   # express polar decomposition in terms of singular-value decomposition
    U, S, Vh = torch.linalg.svd (m)
    u = U @ Vh
    p = Vh.T.conj() @ S.diag().to (dtype = m.dtype) @ Vh
    return u, p

m = torch.randn (5, 5, dtype = torch.complex128)   # works for both real and complex m
print ('sample input m = ...')
print (m)

u, p = polar_decomp (m)

import scipy   # check against scipy
print (scipy.__version__)
uchk, pchk = scipy.linalg.polar (m)
print ('u (pytorch) = uchk (scipy)? :', torch.equal (u, torch.from_numpy (uchk)))
print ('p (pytorch) = pchk (scipy)? :', torch.equal (p, torch.from_numpy (pchk)))

# polar_decomp supports autograd because it is implemented with autograd-supporting pytorch operations
m.requires_grad = True
print ('torch.autograd.gradcheck (polar_decomp, m):', torch.autograd.gradcheck (polar_decomp, m))
```

And here is its output:

```
2.0.1
sample input m = ...
tensor([[ 0.7081+0.9242j, -0.3647+0.2722j, 0.5486-0.4323j, 0.4702+0.7164j,
-0.2536+0.1605j],
[-0.5186-0.4888j, -0.5726-0.3178j, 0.2299-0.3169j, 0.3628+0.4701j,
-0.4271+0.0599j],
[ 0.2517-0.1391j, 0.3351-0.4699j, -0.1361-0.3724j, -1.0979+0.0631j,
1.3223-1.0719j],
[-0.8938-0.0749j, -0.2021+0.4620j, 0.5332+0.7327j, -0.2232-0.0567j,
0.5612+0.0902j],
[ 0.3240+1.1715j, -0.3395+0.5913j, 0.1341+0.5250j, -0.4009-0.7055j,
-0.2356-0.0124j]], dtype=torch.complex128)
1.11.1
u (pytorch) = uchk (scipy)? : True
p (pytorch) = pchk (scipy)? : True
torch.autograd.gradcheck (polar_decomp, m): True
```

Best.

K. Frank


Thanks Frank! I checked the source code of `scipy.linalg.polar` and found that it also uses SVD under the hood. However, JAX's polar decomposition has two implementations: SVD and another one called QDWH.
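As a rough sketch of that: `jax.scipy.linalg.polar` takes a `method` argument to select the algorithm (per my reading of the JAX API, it accepts `"qdwh"` and `"svd"`; check the docs for your JAX version). Both methods should agree up to numerical tolerance:

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array ([[1.0, 2.0], [3.0, 4.0]])

# method = "qdwh" and method = "svd" select JAX's two implementations
u_qdwh, p_qdwh = polar (a, method = "qdwh")
u_svd, p_svd = polar (a, method = "svd")

# each factorization should reconstruct a, and the two should agree
print (jnp.allclose (u_qdwh @ p_qdwh, a, atol = 1e-5))
print (jnp.allclose (u_svd @ p_svd, a, atol = 1e-5))
print (jnp.allclose (u_qdwh, u_svd, atol = 1e-4))
```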