Linear solver for sparse matrices

Hello,
I’d like to solve a linear system Ax=b where A is not square, but I know that there is exactly one solution.

The matrix A is represented as a sparse matrix that cannot be densified because it is too large. I also want the autograd to work on A.

The problem is that the only solutions I found so far are either computing a dense representation of A (which doesn’t work since A is too big) or using scipy (which is not compatible with autograd).

Is there a solution available in pytorch?
Thank you !

I haven’t worked with sparse matrices in particular, but a quick look at the docs shows that matrix multiplication is supported. With that in mind, if your matrix is very tall or very wide, but the other dimension is reasonable, you could try to explicitly form the Moore-Penrose pseudoinverse, which for a tall matrix with full column rank is (A^T A)^(-1) A^T. If you have a very tall matrix with a reasonable number of columns, A^T A may have a reasonable shape, and you could convert that intermediate product into a dense matrix. You could then either compute its explicit inverse (not usually recommended, because of numerical stability issues) or pass it to any linear solver that supports dense matrices.
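
As a rough illustration of that idea, here is a minimal sketch that solves the normal equations (A^T A) x = A^T b for a tall sparse A. The sizes, the random test matrix, and the variable names are assumptions made up for the example, and the approach only makes sense when the number of columns is small enough for A^T A to fit in memory as a dense matrix.

```python
import torch

torch.manual_seed(0)
m, k = 1000, 5                      # tall: many rows, few columns (assumed sizes)

# Build a random sparse A with about 20% non-zero entries (illustration only;
# a real problem would construct A directly in sparse form).
mask = torch.rand(m, k) < 0.2
A = (torch.randn(m, k, dtype=torch.float64) * mask).to_sparse()

# Manufacture a consistent right-hand side from a known solution.
x_true = torch.randn(k, 1, dtype=torch.float64)
b = torch.sparse.mm(A, x_true)            # dense, shape (m, 1)

At = A.t().coalesce()
AtA = torch.sparse.mm(At, A).to_dense()   # small dense (k, k) matrix
Atb = torch.sparse.mm(At, b)              # dense, shape (k, 1)

# Solve the normal equations rather than forming the explicit inverse.
x = torch.linalg.solve(AtA, Atb)
print((x - x_true).abs().max())
```

Note that forming A^T A squares the condition number of the problem, so this is best treated as a quick check rather than a robust solver.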

Both dimensions of the matrix are too large. The thing is that I’m sure that solving a linear system with a very big sparse matrix is differentiable and not computationally very expensive (using Gaussian elimination, for example). That’s why I’d expect pytorch to have such a solver, but I didn’t find any.

Hi Séhane!

Because A is not square, your system of equations must be either
under-determined or over-determined. If it is under-determined, it
has, by definition, more than one solution.

If it is over-determined, you will, in general, have no solutions, so A
must have some special structure that permits it to have a solution.

Could you explain that structure, as well as why it leads to exactly
one solution?

Let’s say that you do compute x somehow. Do you need x to be
differentiable with respect to A or to b or to both? (If you only need
differentiability with respect to b, you could quite easily write a
custom autograd function that uses scipy’s sparse solver for both
the forward and backward passes.)
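
A minimal sketch of such a custom autograd function might look as follows. It assumes, for simplicity, a square invertible A held as a fixed scipy sparse matrix (not the non-square system from the question), and the class name SparseSolveB and the small test matrix are made up for illustration. The backward pass uses the fact that if x = A^(-1) b, then the gradient with respect to b is A^(-T) times the gradient with respect to x.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla
import torch

class SparseSolveB(torch.autograd.Function):
    """x = A^{-1} b for a fixed, square, invertible scipy sparse A (no gradient for A)."""

    @staticmethod
    def forward(ctx, b, A_csc):
        ctx.A_csc = A_csc                      # non-tensor input, stored as-is
        x = spla.spsolve(A_csc, b.detach().numpy())
        return torch.from_numpy(x)

    @staticmethod
    def backward(ctx, grad_x):
        # dL/db = A^{-T} dL/dx, i.e. one more sparse solve with the transpose.
        g = np.ascontiguousarray(grad_x.numpy())
        grad_b = spla.spsolve(ctx.A_csc.T.tocsc(), g)
        return torch.from_numpy(grad_b), None  # None: no gradient for A_csc

# Hypothetical small test system: a non-symmetric, diagonally dominant tridiagonal matrix.
n = 6
A = sp.diags([-1.0, 4.0, -2.0], offsets=[-1, 0, 1], shape=(n, n), format="csc")
b = torch.arange(1.0, n + 1.0, dtype=torch.float64, requires_grad=True)

x = SparseSolveB.apply(b, A)
x.sum().backward()
print(b.grad)
```

This only gives gradients with respect to b; differentiating with respect to the entries of A would need an additional term in backward.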

Best.

K. Frank

Hello,
So I want to differentiate with respect to some parameters in A.
The setting is the following.
I have a matrix P describing a Markov decision process (which means that each row of P is a probability distribution, so it has nonnegative values that sum to one).
I want to find the stationary state, which is defined through the limit of P^n as n tends to infinity. It is characterized as the only vector w such that:

  • wP=w
  • w is a probability distribution, so it has nonnegative values that sum to one

To solve it, I decided to define a matrix X as P-Id with an extra line of ones appended, and Y as a vector of zeros with a one as its last entry.

To make this work you additionally need to transpose the matrices correctly. I implemented it and it worked with dense matrices. The thing is that P is sparse, so X is sparse too, and I want to scale the problem up. That is only possible using sparse matrices.
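
For concreteness, a minimal dense sketch of this construction could look like the following. The 3-state matrix is an assumed toy example, and the transposition is written out explicitly: wP = w becomes (P^T - Id) w^T = 0, with a row of ones appended to enforce that the entries of w sum to one.

```python
import torch

# Small 3-state chain chosen for illustration; each row sums to one.
P = torch.tensor([[0.50, 0.50, 0.00],
                  [0.25, 0.50, 0.25],
                  [0.00, 0.50, 0.50]], dtype=torch.float64)
n = P.shape[0]

# Augmented system: (P^T - Id) stacked on top of a row of ones, shape (n + 1, n).
X = torch.cat((P.T - torch.eye(n, dtype=torch.float64),
               torch.ones(1, n, dtype=torch.float64)), dim=0)

# Right-hand side: zeros with a one in the last entry, shape (n + 1, 1).
Y = torch.zeros(n + 1, 1, dtype=torch.float64)
Y[-1] = 1.0

w = torch.linalg.lstsq(X, Y).solution.squeeze(-1)
print(w)        # stationary distribution, here (0.25, 0.5, 0.25)
print(w @ P)    # should reproduce w
```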

Thank you for your answer !

Hi Séhane!

The short answer is to use the power-iteration method to compute your
stationary state, w. torch.sparse.mm (sparse, dense) supports
backpropagation, so you will be able to compute your gradients without
materializing a dense matrix with the size of P.

Beware that there is some nuance here, as discussed below.

Note that wP = w is the eigenvector equation for w (with eigenvalue one), so
calling the proposed scheme “power iteration” is appropriate.

Solving your augmented linear system with a least-squares solver is also
legitimate, except, as you’ve noted, torch.linalg.lstsq() (and similar
pytorch tools) don’t (yet?) support sparse matrices.

Further detail:

We have two schemes for computing your stationary vector: power iteration
and lstsq(). We can understand both of these as functions that map a
matrix P to its stationary vector, w.

As seen in the script given below, these give the same result, but not the
same gradient. This is because these two schemes aren’t actually the
same function – they only give the same result when applied to a
legitimate markov-process transition matrix.

So you have, conceptually, a bunch of different functions that all give the
same result when applied to a legitimate transition matrix, but that produce
different gradients.

(As a simple example of this, consider f (x) = 1, that is, the constant
function, and g (x) = x**2. These two different functions have the same
value when applied to x = 1, but have different derivatives. In this example
the “condition” that x = 1 is analogous to the condition that P be a valid
transition matrix in your real problem.)
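
This toy example can be checked directly with autograd. In the minimal sketch below, the constant function is written as 0.0 * x + 1.0 purely so that autograd can track its dependence on x; that rewriting is a choice made for the illustration.

```python
import torch

x = torch.tensor(1.0, requires_grad=True)

f = 0.0 * x + 1.0      # the constant function f (x) = 1, kept connected to x
g = x ** 2             # g (x) = x**2

df, = torch.autograd.grad(f, x)
dg, = torch.autograd.grad(g, x)

print(f.item(), g.item())    # same value at x = 1:      1.0 1.0
print(df.item(), dg.item())  # different derivatives:    0.0 2.0
```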

But which gradient is the “correct” one? In some sense that’s up to you
and the details of your use case.

One approach to addressing this problem is to first map the matrices passed
to these functions, which may or may not be transition matrices, to valid
transition matrices and only then proceed to the stationary-state computation.
(I call this “normalization” in the script.) If we ensure that this normalization
map doesn’t change a matrix that is already a transition matrix, we will still
get the same result as from the unnormalized functions.
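
As a small illustration of that property, here is a sketch of one possible normalization map (absolute value followed by row normalization). The name t_norm and the two test matrices are assumptions for the example. It leaves a valid transition matrix essentially unchanged and turns an arbitrary matrix into one whose rows are nonnegative and sum to one.

```python
import torch

def t_norm(T, eps=1e-12):
    # Make entries nonnegative, then rescale each row to sum to one.
    t = T.abs()
    return t / (t.sum(dim=1, keepdim=True) + eps)

# Already a valid transition matrix: normalization should not change it.
P = torch.tensor([[0.5, 0.5], [0.2, 0.8]], dtype=torch.float64)

# Not a transition matrix: negative entry, rows do not sum to one.
Q = torch.tensor([[2.0, -2.0], [1.0, 3.0]], dtype=torch.float64)

print((t_norm(P) - P).abs().max())   # tiny, on the order of eps
print(t_norm(Q))                     # rows become (0.5, 0.5) and (0.25, 0.75)
print(t_norm(Q).sum(dim=1))          # each row sums to one
```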

With this scheme, the normalized versions of power iteration and lstsq()
are the same function as one another and produce the same gradient.

(Note that pytorch’s support for sparse tensors is a moving target. It’s gotten
a lot better recently, but still has a lot of missing functionality and probably
still some bugs. I’ve run my example script with pytorch’s latest stable
version, 2.2.2, and it wouldn’t surprise me if earlier versions don’t work the
same way. You will see some oddities in the example script that are required
due to missing sparse functionality.)

Consider the following script that implements these schemes, including
(normalized) power iteration for sparse P:

import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

n = 100
sparseMult = 3
powerCnt = 100

# create random, sparse, transition matrix
#   this matrix need not be ergodic nor have unique
#   stationary state, but will "probably" be okay

P = torch.rand (n, n)                  # random probability matrix (before masking and normalizing)
rowCounts = torch.poisson ((sparseMult - 1) * torch.ones (n)).long() + 1   # non-zeros per row (at least one)
m = torch.zeros (n, n)                 # to become mask matrix 
for  i, k in enumerate (rowCounts):
    m[i][torch.multinomial (torch.ones (n), k)] = 1

P *= m                                 # mask out most probabilities
P /= P.sum (dim = 1, keepdim = True)   # normalize each row to sum to one

# check sparsity
print ('sparsity (non-zero fraction):     ', (P != 0.0).sum() / P.numel())

# check valid transition matrix
print ('\nverify valid transition matrix:')
print ('any less than zero:               ', (P < 0.0).any())
print ('any greater than one:             ', (P > 1.0).any())
print ('max row-sum discrepancy:          ', (P.sum (dim = 1) - 1.0).abs().max())

PSparse = P.to_sparse()                # also create sparse version of P

# create random initial state vector -- u is a global variable
u = torch.rand (n)
u /= u.sum()

def statePower (T):                    # use "power iteration" to find stationary state
    w = u.clone()
    for  i in range (powerCnt):  w @= T
    return w

def stateLstsq (T):                    # use lstsq to solve linear equation for stationary state
    # form augmented matrix and "result" vector
    A = T.T - torch.eye (n)
    A = torch.cat ((A, torch.ones (1, n)), dim = 0)
    b = torch.zeros (n + 1)
    b[-1] = 1.0
    w = torch.linalg.lstsq (A, b)[0]
    return w

def tNorm (T):                         # "normalize" T to ensure valid transition matrix
    eps = 1.e-12
    t = T.abs()
    t /= (t.sum (dim = 1, keepdim = True) + eps)
    return t

def statePowerNorm (T):                # normalize T before "power iteration"
    return statePower (tNorm (T))

def stateLstsqNorm (T):                # normalize T before lstsq solve
    return stateLstsq (tNorm (T))

def statePowerSparse (T):              # version of statePower for sparse T
    w = u.clone().unsqueeze (-1)
    TT = T.T
    for  i in range (powerCnt):  w = torch.sparse.mm (TT, w)
    return w.squeeze()

def tNormSparse (T):                   # "normalize" sparse T to ensure valid transition matrix
    eps = 1.e-12
    t = T.abs()
    rowSum = torch.sparse.sum (t, dim = 1)
    rowSumInv = 1.0 / (rowSum.to_dense() + eps)
    # t *= rowSumInv                   # can't backpropagate through broadcast mul(), so do it "by hand"
    expVals = rowSumInv[t.indices()[0]]
    rsiExpand = torch.sparse_coo_tensor (t.indices(), expVals, t.shape).coalesce()
    t = t * rsiExpand
    return t

def statePowerNormSparse (T):          # version of statePowerNorm for sparse T
    return statePowerSparse (tNormSparse (T))

wPwr = statePower (P)
wLst = stateLstsq (P)
wPwrN = statePowerNorm (P)
wLstN = stateLstsqNorm (P)

print ('\ncheck agreement between state-vector algorithms:')

# computed state vectors agree
print ('max diff between wPwr and wLst:   ', (wPwr - wLst).abs().max())

# "norm" versions of state vectors also agree
print ('max diff between wPwrN and wPwr:  ', (wPwrN - wPwr).abs().max())
print ('max diff between wPwrN and wLstN: ', (wPwrN - wLstN).abs().max())

# compute gradients
P.requires_grad = True

# compute grad from power iteration
wPwr = statePower (P)
(wPwr**2).sum().backward()             # call backward on some scalar function of wPwr
gradPower = P.grad.clone()

# compute grad from lstsq solution
P.grad = None
wLst = stateLstsq (P)
(wLst**2).sum().backward()             # call backward on the same scalar function of wLst
gradLstsq = P.grad.clone()

print ('\ncheck disagreement / agreement for gradients of algorithms:')

# but gradients from the two versions disagree
print ('max diff between grads:           ', (gradPower - gradLstsq).abs().max())

# compute gradients using "norm" versions

# compute powerNorm grad
P.grad = None
wPwrN = statePowerNorm (P)
(wPwrN**2).sum().backward()            # backward on wPwrN
gradPowerNorm = P.grad.clone()

# compute lstsqNorm grad
P.grad = None
wLstN = stateLstsqNorm (P)
(wLstN**2).sum().backward()            # backward on wLstN
gradLstsqNorm = P.grad.clone()

# however the gradients from the "norm" versions agree
print ('max diff between "norm" grads:    ', (gradPowerNorm - gradLstsqNorm).abs().max())

# compute state vector and gradient using sparse power "norm" version

PSparse.requires_grad = True

wSprN = statePowerNormSparse (PSparse)
(wSprN**2).sum().backward()
gradSparseNorm = PSparse.grad.clone().coalesce()

print ('\ncheck state vector and gradient for sparse algorithm:')

# sparse-version of state vector agrees
print ('max diff between wSprN and wLstN: ', (wSprN - wLstN).abs().max())

# sparse gradient agrees with corresponding elements of dense gradient (using "norm" versions)
print ('max diff for sparse "norm" grad:  ', (gradSparseNorm.values() - gradLstsqNorm[PSparse.indices().unbind()]).abs().max())

Here is the script’s output:

2.2.2
sparsity (non-zero fraction):      tensor(0.0275)

verify valid transition matrix:
any less than zero:                tensor(False)
any greater than one:              tensor(False)
max row-sum discrepancy:           tensor(1.1921e-07)

check agreement between state-vector algorithms:
max diff between wPwr and wLst:    tensor(7.4506e-08)
max diff between wPwrN and wPwr:   tensor(1.8626e-08)
max diff between wPwrN and wLstN:  tensor(7.5554e-08)

check disagreement / agreement for gradients of algorithms:
max diff between grads:            tensor(0.2146)
max diff between "norm" grads:     tensor(5.9372e-08)

check state vector and gradient for sparse algorithm:
max diff between wSprN and wLstN:  tensor(8.5808e-08, grad_fn=<MaxBackward1>)
max diff for sparse "norm" grad:   tensor(8.4285e-08)

Best.

K. Frank