In a bid to get familiar with PyTorch syntax, I thought I’d try and see if I can use gradient descent to do SVD - but not just the standard SVD routine, instead multidimensional scaling (MDS) which requires SVD.

Essentially, I generated a random `n x n`

matrix `U`

, a random diagonal `n x n`

matrix `s`

, and a random `n x n`

matrix `Vh`

, just as a starting point. The goal is for `U s Vh.T`

to approximate some matrix `B`

ie. `U s Vh.T ~ SVD(B)`

.

I guess I’m running into two rookie pitfalls: (1) the loss is not updating after the first iteration (why?) (2) is it possible to “combine” two loss functions? Below, I “combine” the loss of the SVD approximation to the loss of the MDS approximation:

```
# load into pytorch
D = torch.rand(N, N).float()
H = torch.eye(N).float() - \
1/N * torch.ones(D.shape).float()
B = 0.5 * torch.matmul(torch.matmul(H, D**2), H)
pts = torch.rand(N, 3).float()
# declare constants
stop_loss = 1e-2
step_size = stop_loss / 3
# emulate simulate SVD
U = torch.autograd.Variable(torch.rand(N, N), requires_grad = True)
s = torch.autograd.Variable(torch.diag(torch.sort(torch.rand(N), descending = True).values), requires_grad = True)
Vh = torch.autograd.Variable(torch.rand(N, N), requires_grad = True)
# find embedding
embed = torch.matmul(torch.sqrt(s), Vh)
X_hat = embed.T[:, :3] # select the first 3 coordinates --> x, y, z
for i in range(100000):
# calculate loss1: how close is our SVD function?
delta1 = (torch.matmul(U.T, U) - torch.eye(N)) + \
(torch.matmul(V.T, V) - torch.eye(N)) + \
(torch.matmul(torch.matmul(U, s), Vh.T) - B)
L1 = torch.norm(delta1, p=2)
# calculate loss2: how close is our MDS approximation?
delta2 = torch.nan_to_num(X_hat - pts, nan=0.0)
L2 = torch.norm(delta2, p=2)
# Backprop
loss = L1 + L2
loss.backward()
# update
U.data -= step_size * U.grad.data
s.data -= step_size * s.grad.data
V.data -= step_size * V.grad.data
U.data.data.zero_()
s.data.data.zero_()
V.data.data.zero_()
if i % 1000 == 0:
print('Loss is %s at iteration %i' % (loss, i))
if abs(loss) < stop_loss:
break
```