Does pytorch have a function that calculates the polar decomposition of a matrix, like scipy.linalg.polar? The function torch.polar is an entirely different one. If not, why does pytorch choose to omit this feature?
Hi Steve!
Not that I’m aware of.
Perhaps pytorch should provide polar decomposition. But it’s straightforward
enough to implement it in terms of pytorch’s singular-value decomposition.
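For reference, the identity behind this can be written out as follows. This is only a sketch for a square matrix m, using U, S, and V^H for the factors returned by the singular-value decomposition:

```latex
\begin{aligned}
m &= U \,\mathrm{diag}(S)\, V^{H}
   && \text{(singular-value decomposition)} \\
  &= U \,(V^{H} V)\, \mathrm{diag}(S)\, V^{H}
   && \text{(insert } V^{H} V = I\text{)} \\
  &= \underbrace{(U V^{H})}_{u}\;
     \underbrace{(V \,\mathrm{diag}(S)\, V^{H})}_{p}
\end{aligned}
```

Here u is unitary because it is a product of unitary matrices, and p is Hermitian positive semidefinite because the singular values in S are non-negative.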
Here is such a pytorch implementation in a script that verifies its result against
scipy and checks its autograd differentiability:
import torch
print (torch.__version__)
_ = torch.manual_seed (2023)
def polar_decomp (m):   # express polar decomposition in terms of singular-value decomposition
    U, S, Vh = torch.linalg.svd (m)
    u = U @ Vh
    p = Vh.T.conj() @ S.diag().to (dtype = m.dtype) @ Vh
    return u, p
m = torch.randn (5, 5, dtype = torch.complex128) # works for both real and complex m
print ('sample input m = ...')
print (m)
u, p = polar_decomp (m)
import scipy # check against scipy
print (scipy.__version__)
uchk, pchk = scipy.linalg.polar (m)
# compare with a tolerance, since exact floating-point equality across libraries is fragile
print ('u (pytorch) = uchk (scipy)? :', torch.allclose (u, torch.from_numpy (uchk)))
print ('p (pytorch) = pchk (scipy)? :', torch.allclose (p, torch.from_numpy (pchk)))
# polar_decomp supports autograd because it is implemented with autograd-supporting pytorch operations
m.requires_grad = True
print ('torch.autograd.gradcheck (polar_decomp, m):', torch.autograd.gradcheck (polar_decomp, m))
And here is its output:
2.0.1
sample input m = ...
tensor([[ 0.7081+0.9242j, -0.3647+0.2722j,  0.5486-0.4323j,  0.4702+0.7164j,
         -0.2536+0.1605j],
        [-0.5186-0.4888j, -0.5726-0.3178j,  0.2299-0.3169j,  0.3628+0.4701j,
         -0.4271+0.0599j],
        [ 0.2517-0.1391j,  0.3351-0.4699j, -0.1361-0.3724j, -1.0979+0.0631j,
          1.3223-1.0719j],
        [-0.8938-0.0749j, -0.2021+0.4620j,  0.5332+0.7327j, -0.2232-0.0567j,
          0.5612+0.0902j],
        [ 0.3240+1.1715j, -0.3395+0.5913j,  0.1341+0.5250j, -0.4009-0.7055j,
         -0.2356-0.0124j]], dtype=torch.complex128)
1.11.1
u (pytorch) = uchk (scipy)? : True
p (pytorch) = pchk (scipy)? : True
torch.autograd.gradcheck (polar_decomp, m): True
Best.
K. Frank
Thanks Frank! I checked the source code of scipy.linalg.polar and found that it also uses SVD under the hood. JAX's polar decomposition, however, has two implementations: SVD and another one called qdwh.
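Since the SVD route is so short, here is a minimal sketch of how it might be extended to mirror the side argument of scipy.linalg.polar ("right" for m = u p, "left" for m = p u). The name polar_decomp_sided is hypothetical, not a pytorch function, and the checks only confirm the defining properties on one random matrix:

```python
import torch

def polar_decomp_sided(m, side="right"):
    # Hypothetical helper (not part of pytorch): polar decomposition via SVD.
    # side="right" gives m = u @ p; side="left" gives m = p @ u.
    U, S, Vh = torch.linalg.svd(m, full_matrices=False)
    u = U @ Vh
    S = S.to(m.dtype)  # singular values are real; cast so they combine with complex factors
    if side == "right":
        p = Vh.mH @ (S.unsqueeze(-1) * Vh)  # V diag(S) V^H
    elif side == "left":
        p = (U * S.unsqueeze(-2)) @ U.mH    # U diag(S) U^H
    else:
        raise ValueError("side must be 'right' or 'left'")
    return u, p

torch.manual_seed(0)
m = torch.randn(4, 4, dtype=torch.complex128)
eye = torch.eye(4, dtype=m.dtype)

# right-sided: m = u p, with u unitary and p Hermitian
u, p = polar_decomp_sided(m, side="right")
right_ok = (torch.allclose(u @ p, m)
            and torch.allclose(u.mH @ u, eye)
            and torch.allclose(p, p.mH))

# left-sided: m = p u, with the same unitary factor u
u_l, p_l = polar_decomp_sided(m, side="left")
left_ok = (torch.allclose(p_l @ u_l, m)
           and torch.allclose(u_l, u)
           and torch.allclose(p_l, p_l.mH))

print("right-sided checks pass:", right_ok)
print("left-sided checks pass:", left_ok)
```

Because it uses only autograd-supporting pytorch operations, this sketch should be differentiable in the same way as the version above, though I have not run gradcheck on it.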