Matrix factorisation using gradient descent: performing singular value decomposition on a toy dataset

I’m trying to implement singular value decomposition, which factorises a matrix M^(e,f) into three matrices U^(e,e), S^(e,f) and V^(f,f) such that M = U S V^T, where U and V are orthogonal (their columns form orthonormal bases) and S is diagonal, by phrasing the problem as an optimisation one.

I create three loss functions: one for the reconstruction error, one for the diagonal constraint on S, and one for the orthonormality constraint on U and V.

I understand there are direct algorithmic approaches to SVD, and torch implements one too, but I’m trying to see whether it can also be solved through optimisation.
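(For reference, the direct route I’m comparing against is roughly the following; note that torch returns the singular values as a vector rather than as a rectangular S matrix.)

import torch

M = torch.randn(44, 8)
U, S, Vh = torch.linalg.svd(M)              # U: (44, 44), S: (8,), Vh: (8, 8)
M_rebuilt = U[:, :8] @ torch.diag(S) @ Vh   # recovers M up to floating-point error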

Here is my code:


import torch 
import torch.distributions as d 
from functools import reduce
import math 

STEPS = 100
STEP_SIZE = 1e-10
product = lambda x, y: x * y
dims = (44, 8)

# uniform categorical distribution over the integers 0..49
probs = torch.Tensor([1] * 50)
probs = (1 / probs.sum()) * probs


def loss_non_orthonormality(X):
    # X is expected to be square (U or V); its columns should form an
    # orthonormal basis, so X.T @ X should equal the identity:
    #   diagonal elements     ----> should all be 1
    #   off-diagonal elements ----> should all be 0
    sim = X.T @ X
    sim_diag = torch.diagflat(sim.diag())
    idt = torch.eye(X.shape[1])
    non_normal = sim_diag - idt
    sim_off_diag = sim - sim_diag
    l1 = torch.mean(torch.square(non_normal))
    l2 = torch.mean(torch.square(sim_off_diag))
    return l1 + l2

def loss_off_diagonal(S):
    # Rebuild the "diagonal part" of the rectangular S (the top-left square
    # block, padded below with zero rows) and penalise everything outside it.
    zeros = torch.zeros(size=(S.shape[0] - torch.diagflat(S.diag()).shape[0], S.shape[1]))
    diag = torch.diagflat(S.diag())
    S_diag = torch.vstack([diag, zeros])
    non_zero = S - S_diag
    l = torch.mean(torch.square(non_zero))
    return l
    

def loss_reconstruction(m, U, S, V):
    # mean squared error between M and its reconstruction U @ S @ V.T
    mr = m - (U @ S @ V.T)
    l = torch.mean(torch.square(mr))
    return l
 
def func_weighting(f1, f2, f3, f4, w1=0.06, w2=0.00, w3=0.00, w4=0.00):
    # weighted sum of the four losses; the weights are hand-tuned (see below)
    return w1 * f1 + w2 * f2 + w3 * f3 + w4 * f4

# Stopping condition: stop when the last `patience` recorded losses are all identical
stop = lambda err, patience: all(e == err[-patience:][0] for e in err[-patience:])
patience = 5
err = []

def fit():
    # random integer-valued initialisations for M, U, S and V
    M = torch.Tensor([d.categorical.Categorical(probs=probs).sample() for _ in range(reduce(product, dims))]).reshape(dims)
    U = torch.Tensor([d.categorical.Categorical(probs=probs).sample() for _ in range(reduce(product, (dims[0], dims[0])))]).reshape((dims[0], dims[0]))
    S = torch.Tensor([d.categorical.Categorical(probs=probs).sample() for _ in range(reduce(product, dims))]).reshape(dims)
    V = torch.Tensor([d.categorical.Categorical(probs=probs).sample() for _ in range(reduce(product, (dims[1], dims[1])))]).reshape((dims[1], dims[1]))
    U.requires_grad = True
    S.requires_grad = True
    V.requires_grad = True

    for s in range(STEPS):
        loss = func_weighting(loss_reconstruction(M, U, S, V),
                              loss_off_diagonal(S),
                              loss_non_orthonormality(U),
                              loss_non_orthonormality(V))
        loss.backward()
        print(f"Step {s} loss: {loss.data}")
        # manual gradient-descent update and gradient reset
        U = U - STEP_SIZE * U.grad
        S = S - STEP_SIZE * S.grad
        V = V - STEP_SIZE * V.grad
        U.grad = torch.zeros_like(U)
        S.grad = torch.zeros_like(S)
        V.grad = torch.zeros_like(V)
        err.append(loss.data)
        if math.isnan(loss.data):
            print("Fitting diverged")
            raise RuntimeError("Loss became NaN")
        if stop(err, patience) and s > patience + 1:
            return (M, U, S, V)
        
M, U, S, V = fit()

The loss weighting function exists because, when all the weights are equal, the fitting diverges. Only when the losses are weighted down to around a tenth of their actual values does the script converge, and even then it does so in two steps.

Does anybody have any ideas on whether this is feasible, or am I just missing something obvious? I’ve tried different step sizes / learning rates, but they all either diverge or “converge” in one or two steps.

Hi Akin!

Yes, this should be feasible (but, as you note, this wouldn’t be the “right” way to do it).

I don’t understand what you mean by “one/two steps.” You won’t be able to project
a randomly-initialized matrix onto an orthogonal matrix (your U or V matrix) in one
or two gradient-descent steps, so I don’t see how you could get “convergence” that
quickly.

The best way to explain what you mean would be to use torch.manual_seed()
to make things reproducible and post the results you get when you run your code.
It would also be useful to illustrate your issue with smaller values for dims, just
to make things easier to look at.
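
Something like the following at the top of your script would do (the seed and the
shape are just placeholder values):

torch.manual_seed(2024)   # placeholder seed -- any fixed value makes every run identical
dims = (5, 3)             # placeholder small shape -- easier to eyeball than (44, 8)

With that in place, the loss values you post will be the same ones anybody else sees
when they run your script.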

Some comments:

As it stands, you are initializing U and V – which are to be optimized into orthogonal
matrices – with integer matrix elements of order ten. This shouldn’t break anything
outright, but it is rather perverse.
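
For what it’s worth, one way to get a friendlier starting point (just a sketch, not
the only option) is to orthogonalize a random Gaussian matrix with a QR decomposition:

def random_orthogonal(n):
    # QR of a random Gaussian matrix yields an orthogonal Q -- a much more
    # natural starting point for U and V than integer entries of order ten
    A = torch.randn(n, n)
    Q, _ = torch.linalg.qr(A)
    return Q

U = random_orthogonal(dims[0])
V = random_orthogonal(dims[1])

With that initialization your two orthonormality losses start out near zero, so the
optimization only has to maintain the constraint rather than discover it.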

There is a lot of non-uniqueness in the singular-value decomposition. For starters,
the order of the singular values (your S) is arbitrary (unless fixed by imposing some
additional conditions). Furthermore, in the presence of degenerate singular values,
the singular vectors become non-unique. Again, this won’t break anything outright,
but you need to be aware of this when you look at the results of your decomposition.
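
Here is a small illustration of the sign ambiguity (using pytorch’s built-in svd just
to get an exact decomposition to play with):

import torch

M = torch.randn(5, 3)
U, sing, Vh = torch.linalg.svd(M)
S = torch.zeros(5, 3)
S[:3, :3] = torch.diag(sing)
V = Vh.T
# flip the sign of a matching pair of columns of U and V
U2, V2 = U.clone(), V.clone()
U2[:, 0] *= -1
V2[:, 0] *= -1
print(torch.allclose(U @ S @ V.T, U2 @ S @ V2.T))   # True -- same M, different U and V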

You can certainly implement your own gradient descent, as you have done. But you
might consider using one of pytorch’s built-in optimizers. This would make it easier
to experiment with variations on the optimization algorithm.
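
As a sketch of what I mean (this drops into your fit() function, using your loss
functions and the U, S, V leaf tensors you already create; SGD is just an example,
Adam and friends have the same interface):

optimizer = torch.optim.SGD([U, S, V], lr=1e-3)   # the lr value is a placeholder

for s in range(STEPS):
    optimizer.zero_grad()                         # replaces the manual grad zeroing
    loss = func_weighting(loss_reconstruction(M, U, S, V),
                          loss_off_diagonal(S),
                          loss_non_orthonormality(U),
                          loss_non_orthonormality(V))
    loss.backward()
    optimizer.step()                              # replaces the manual U/S/V updates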

When you start such an optimization with randomly-initialized values, you are likely
to start far from the optimum and have large gradients. To avoid divergence or other
bad behavior, you then need to use a small step size. I often find it helpful to train
for a little while – as determined by trial and error – with a small step size until the
optimization moves to a more “reasonable” point in configuration space and the
gradients become more moderate, and then switch to a larger step size so that the
optimization proceeds more rapidly.
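
Concretely, with one of the built-in optimizers that two-phase schedule is just a
matter of changing the learning rate in optimizer.param_groups partway through
training; the switch point and the two learning rates below are made-up numbers
that you would tune by trial and error:

optimizer = torch.optim.SGD([U, S, V], lr=1e-9)   # small warm-up step size

for s in range(STEPS):
    if s == 50:                                   # switch point found by trial and error
        for group in optimizer.param_groups:
            group["lr"] = 1e-5                    # larger step size for the rest of the run
    optimizer.zero_grad()
    loss = func_weighting(loss_reconstruction(M, U, S, V),
                          loss_off_diagonal(S),
                          loss_non_orthonormality(U),
                          loss_non_orthonormality(V))
    loss.backward()
    optimizer.step()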

Last, although I didn’t notice any errors, I haven’t looked at your code closely, so
there could, of course, be an outright bug somewhere.

Best.

K. Frank