Hi Séhane!
The short answer is to use the power-iteration method to compute your stationary state, $w$. torch.sparse.mm (sparse, dense) supports backpropagation, so you will be able to compute your gradients without materializing a dense matrix the size of $P$.
Beware that there is some nuance here, as discussed below.
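Here is a minimal sketch of that idea on a tiny, hand-written 4-state chain. The matrix, the starting vector u, and the helper name power_sparse are illustrative assumptions, not part of your problem. It stores $P$ transposed as a sparse COO tensor so that each step is a single torch.sparse.mm (sparse, dense) call, backpropagates a scalar loss into the sparse tensor, and compares the sparse gradient with the same computation done densely. This assumes a recent pytorch 2.x; sparse behavior may differ in other versions.

```python
import torch

# Hand-written 4-state transition matrix (rows sum to one). It is doubly
# stochastic and has self-loops, so its unique stationary vector is uniform.
P = torch.tensor([[0.5, 0.5, 0.0, 0.0],
                  [0.0, 0.5, 0.5, 0.0],
                  [0.0, 0.0, 0.5, 0.5],
                  [0.5, 0.0, 0.0, 0.5]])
powerCnt = 200
u = torch.tensor([0.7, 0.1, 0.1, 0.1])   # arbitrary starting distribution

def power_sparse(PT, w0, steps):
    # PT holds P transposed, so PT @ w (column vector) is the column form of w P
    w = w0.unsqueeze(-1)
    for _ in range(steps):
        w = torch.sparse.mm(PT, w)
    return w.squeeze(-1)

# sparse leaf tensor: gradients flow back into it without densifying P
PTSparse = P.T.to_sparse().requires_grad_()
w = power_sparse(PTSparse, u, powerCnt)
(w**2).sum().backward()                  # some scalar function of w
gradSparse = PTSparse.grad.coalesce()

# dense reference computation of the same function, for comparison
PTDense = P.T.clone().requires_grad_()
wd = u.unsqueeze(-1)
for _ in range(powerCnt):
    wd = PTDense @ wd
(wd.squeeze(-1)**2).sum().backward()

print(w)
idx = gradSparse.indices()
print((gradSparse.values() - PTDense.grad[idx[0], idx[1]]).abs().max())
```

The sparse gradient is itself a sparse tensor, so its values can be compared directly with the dense gradient at the same indices.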
Note that the stationary-state condition, $w P = w$, is the (left) eigenvector equation for $w$ with eigenvalue one, so calling the proposed scheme “power iteration” is appropriate.
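As a plain-Python illustration of that eigenvector view (the 3-state matrix is a hypothetical example, chosen so that its stationary vector works out to (0.6, 0.3, 0.1)), repeatedly applying $w \leftarrow w P$ drives the residual of $w P = w$ down to round-off level:

```python
# Hypothetical row-stochastic transition matrix (each row sums to one).
P = [[0.9, 0.1, 0.0],
     [0.2, 0.7, 0.1],
     [0.0, 0.3, 0.7]]

def vec_mat(w, M):
    # row vector times matrix: (w M)_j = sum_i w_i M_ij
    return [sum(w[i] * M[i][j] for i in range(len(w))) for j in range(len(M[0]))]

def power_iteration(M, steps=500):
    n = len(M)
    w = [1.0 / n] * n        # start from the uniform distribution
    for _ in range(steps):
        w = vec_mat(w, M)    # w <- w M
    return w

w = power_iteration(P)
# largest entry of |w P - w|: how far w is from an eigenvalue-one eigenvector
residual = max(abs(a - b) for a, b in zip(vec_mat(w, P), w))
print(w, residual)
```

The convergence rate is set by the second-largest eigenvalue of $P$ in magnitude (0.8 for this matrix), which is why a fixed iteration count like the script's powerCnt is usually enough for a well-mixing chain.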
Solving for $w$ as a linear system is also legitimate, except, as you’ve noted, torch.linalg.lstsq() (and similar pytorch tools) don’t (yet?) support sparse matrices.
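For reference, the lstsq() route solves an overdetermined linear system: the stationary condition written for the column vector $w^{\mathsf T}$, plus one extra row that forces the entries of $w$ to sum to one ($I$ is the identity and $\mathbf{1}$ a vector of ones). This is the augmented system that stateLstsq() in the script builds. When the stationary state is unique, the system is consistent and has full column rank, so least squares recovers $w$ exactly (up to round-off):

$$
\begin{pmatrix} P^{\mathsf T} - I \\ \mathbf{1}^{\mathsf T} \end{pmatrix} w^{\mathsf T}
= \begin{pmatrix} \mathbf{0} \\ 1 \end{pmatrix}
$$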
Further detail:
We have two schemes for computing your stationary vector: power iteration and lstsq(). We can understand both of these as functions that map a matrix $P$ to its stationary vector, $w$.
As seen in the script given below, these give the same result, but not the same gradient. This is because these two schemes aren’t actually the same function; they only give the same result when applied to a legitimate Markov-process transition matrix.
So you have, conceptually, a bunch of different functions that all give the
same result when applied to a legitimate transition matrix, but that produce
different gradients.
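A tiny plain-Python sketch makes this concrete without pytorch. For a 2-state chain $P = [[p_{00}, p_{01}], [p_{10}, p_{11}]]$, the first entry of the stationary vector can be written using only the off-diagonal entries or only the diagonal ones. The names stat_a, stat_b, and d_dp00 are hypothetical, and derivatives are estimated by central finite differences rather than autograd:

```python
# Two formulas for the first stationary-vector entry of a 2-state chain.
# They agree whenever each row sums to one (p01 = 1 - p00, p10 = 1 - p11).

def stat_a(p00, p01, p10, p11):
    # uses only the off-diagonal entries
    return p10 / (p01 + p10)

def stat_b(p00, p01, p10, p11):
    # uses only the diagonal entries
    return (1.0 - p11) / ((1.0 - p00) + (1.0 - p11))

def d_dp00(f, p, h=1e-6):
    # central finite-difference estimate of the partial derivative w.r.t. p00
    p00, p01, p10, p11 = p
    return (f(p00 + h, p01, p10, p11) - f(p00 - h, p01, p10, p11)) / (2.0 * h)

P = (0.9, 0.1, 0.3, 0.7)                     # a valid transition matrix
print(stat_a(*P), stat_b(*P))                 # same value, 0.75 up to round-off
print(d_dp00(stat_a, P), d_dp00(stat_b, P))   # 0.0 versus about 1.875
```

Both formulas are "correct" on valid matrices, yet stat_a ignores $p_{00}$ entirely while stat_b does not, which is exactly the power-iteration versus lstsq() situation in miniature.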
(As a simple example of this, consider $f(x) = 1$, that is, the constant function, and $g(x) = x^2$. These two different functions have the same value when applied to $x = 1$, but have different derivatives there: $f'(1) = 0$, while $g'(1) = 2$. In this example the “condition” that $x = 1$ is analogous to the condition that $P$ be a valid transition matrix in your real problem.)
But which gradient is the “correct” one? In some sense that’s up to you
and the details of your use case.
One approach to addressing this problem is to first map the input matrix, which may or may not be a transition matrix, to a valid transition matrix, and then proceed to the stationary-state computation. (I call this “normalization” in the script.) If we ensure that this normalization map doesn’t change a matrix that is already a transition matrix, we will still get the same result as from the unnormalized functions.
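Here is a plain-Python mirror of the script’s tNorm() (t_norm is a hypothetical name, and nested lists stand in for tensors) that checks both properties: arbitrary input comes out row-stochastic, and an already-valid matrix is left unchanged apart from the tiny eps guard against division by zero:

```python
def t_norm(T, eps=1e-12):
    # take absolute values, then scale each row to sum to one (like tNorm())
    out = []
    for row in T:
        a = [abs(x) for x in row]
        s = sum(a) + eps
        out.append([x / s for x in a])
    return out

valid = [[0.9, 0.1], [0.3, 0.7]]          # already a transition matrix
arbitrary = [[-2.0, 1.0], [0.5, 3.5]]     # not a transition matrix

print(t_norm(valid))       # essentially unchanged
print(t_norm(arbitrary))   # rows now non-negative and summing to one
```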
With this scheme, the normalized versions of power iteration and lstsq()
are the same function as one another and produce the same gradient.
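Continuing the 2-state illustration in plain Python (all names hypothetical; normalization here omits eps so that valid rows map onto themselves up to round-off), composing each formula with normalization turns them into the same function of the raw entries, so their finite-difference derivatives now agree:

```python
def normalize(p00, p01, p10, p11):
    # absolute values, each row scaled to sum to one
    r0 = abs(p00) + abs(p01)
    r1 = abs(p10) + abs(p11)
    return abs(p00) / r0, abs(p01) / r0, abs(p10) / r1, abs(p11) / r1

def stat_a(p00, p01, p10, p11):
    return p10 / (p01 + p10)                           # off-diagonal formula

def stat_b(p00, p01, p10, p11):
    return (1.0 - p11) / ((1.0 - p00) + (1.0 - p11))   # diagonal formula

def stat_a_norm(*p):
    return stat_a(*normalize(*p))

def stat_b_norm(*p):
    return stat_b(*normalize(*p))

def d_dp00(f, p, h=1e-6):
    # central finite-difference estimate of the partial derivative w.r.t. p00
    p00, p01, p10, p11 = p
    return (f(p00 + h, p01, p10, p11) - f(p00 - h, p01, p10, p11)) / (2.0 * h)

P = (0.9, 0.1, 0.3, 0.7)
print(d_dp00(stat_a_norm, P), d_dp00(stat_b_norm, P))  # both about 0.1875
```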
(Note that pytorch’s support for sparse tensors is a moving target. It’s gotten
a lot better recently, but still has a lot of missing functionality and probably
still some bugs. I’ve run my example script with pytorch’s latest stable
version, 2.2.2, and it wouldn’t surprise me if earlier versions don’t work the
same way. You will see some oddities in the example script that are required
due to missing sparse functionality.)
Consider the following script that implements these schemes, including (normalized) power iteration for sparse $P$:
```python
import torch
print (torch.__version__)

_ = torch.manual_seed (2024)

n = 100
sparseMult = 3
powerCnt = 100

# create random, sparse, transition matrix
# this matrix need not be ergodic nor have unique
# stationary state, but will "probably" be okay
P = torch.rand (n, n)   # random probability matrix (before masking and normalizing)
rowCounts = torch.poisson ((sparseMult - 1) * torch.ones (n)).long() + 1   # non-zeros per row, at least one
m = torch.zeros (n, n)   # to become mask matrix
for i, k in enumerate (rowCounts):
    m[i][torch.multinomial (torch.ones (n), k)] = 1   # k distinct random columns in row i
P *= m   # mask out most probabilities
P /= P.sum (dim = 1, keepdim = True)   # normalize each row to sum to one

# check sparsity
print ('sparsity (non-zero fraction): ', (P != 0.0).sum() / P.numel())

# check valid transition matrix
print ('\nverify valid transition matrix:')
print ('any less than zero: ', (P < 0.0).any())
print ('any greater than one: ', (P > 1.0).any())
print ('max row-sum discrepancy: ', (P.sum (dim = 1) - 1.0).abs().max())

PSparse = P.to_sparse()   # also create sparse version of P

# create random initial state vector -- u is a global variable
u = torch.rand (n)
u /= u.sum()

def statePower (T):   # use "power iteration" to find stationary state
    w = u.clone()
    for i in range (powerCnt):
        w = w @ T   # w <- w T (row vector times transition matrix)
    return w

def stateLstsq (T):   # use lstsq to solve linear equation for stationary state
    # form augmented matrix: (T^T - I) w = 0, plus a final row enforcing sum (w) = 1
    A = T.T - torch.eye (n)
    A = torch.cat ((A, torch.ones (1, n)), dim = 0)
    b = torch.zeros (n + 1)
    b[-1] = 1.0
    w = torch.linalg.lstsq (A, b)[0]
    return w

def tNorm (T):   # "normalize" T to ensure valid transition matrix
    eps = 1.e-12
    t = T.abs()
    t /= (t.sum (dim = 1, keepdim = True) + eps)
    return t

def statePowerNorm (T):   # normalize T before "power iteration"
    return statePower (tNorm (T))

def stateLstsqNorm (T):   # normalize T before lstsq
    return stateLstsq (tNorm (T))

def statePowerSparse (T):   # version of statePower for sparse T
    w = u.clone().unsqueeze (-1)   # column vector
    TT = T.T   # (w T)^T = T^T w^T, so multiply by T^T from the left
    for i in range (powerCnt):
        w = torch.sparse.mm (TT, w)
    return w.squeeze()

def tNormSparse (T):   # "normalize" sparse T to ensure valid transition matrix
    eps = 1.e-12
    t = T.abs()
    rowSum = torch.sparse.sum (t, dim = 1)
    rowSumInv = 1.0 / (rowSum.to_dense() + eps)
    # t *= rowSumInv   # can't backpropagate through broadcast mul(), so do it "by hand":
    # build a sparse tensor with t's sparsity pattern holding each row's 1 / rowSum
    expVals = rowSumInv[t.indices()[0]]
    rsiExpand = torch.sparse_coo_tensor (t.indices(), expVals, t.shape).coalesce()
    t = t * rsiExpand
    return t

def statePowerNormSparse (T):   # version of statePowerNorm for sparse T
    return statePowerSparse (tNormSparse (T))

wPwr = statePower (P)
wLst = stateLstsq (P)
wPwrN = statePowerNorm (P)
wLstN = stateLstsqNorm (P)

print ('\ncheck agreement between state-vector algorithms:')
# computed state vectors agree
print ('max diff between wPwr and wLst: ', (wPwr - wLst).abs().max())
# "norm" versions of state vectors also agree
print ('max diff between wPwrN and wPwr: ', (wPwrN - wPwr).abs().max())
print ('max diff between wPwrN and wLstN: ', (wPwrN - wLstN).abs().max())

# compute gradients
P.requires_grad = True

# compute grad from power iteration
wPwr = statePower (P)
(wPwr**2).sum().backward()   # call backward on some scalar function of wPwr
gradPower = P.grad.clone()

# compute grad from lstsq solution
P.grad = None
wLst = stateLstsq (P)
(wLst**2).sum().backward()   # call backward on the same scalar function of wLst
gradLstsq = P.grad.clone()

print ('\ncheck disagreement / agreement for gradients of algorithms:')
# but gradients from the two versions disagree
print ('max diff between grads: ', (gradPower - gradLstsq).abs().max())

# compute gradients using "norm" versions
# compute powerNorm grad
P.grad = None
wPwrN = statePowerNorm (P)
(wPwrN**2).sum().backward()   # backward on wPwrN
gradPowerNorm = P.grad.clone()

# compute lstsqNorm grad
P.grad = None
wLstN = stateLstsqNorm (P)
(wLstN**2).sum().backward()   # backward on wLstN
gradLstsqNorm = P.grad.clone()

# however the gradients from the "norm" versions agree
print ('max diff between "norm" grads: ', (gradPowerNorm - gradLstsqNorm).abs().max())

# compute state vector and gradient using sparse power "norm" version
PSparse.requires_grad = True
wSprN = statePowerNormSparse (PSparse)
(wSprN**2).sum().backward()
gradSparseNorm = PSparse.grad.clone().coalesce()

print ('\ncheck state vector and gradient for sparse algorithm:')
# sparse-version of state vector agrees
print ('max diff between wSprN and wLstN: ', (wSprN - wLstN).abs().max())
# sparse gradient agrees with corresponding elements of dense gradient (using "norm" versions)
print ('max diff for sparse "norm" grad: ', (gradSparseNorm.values() - gradLstsqNorm[PSparse.indices().unbind()]).abs().max())
```
Here is the script’s output:
```
2.2.2
sparsity (non-zero fraction): tensor(0.0275)

verify valid transition matrix:
any less than zero: tensor(False)
any greater than one: tensor(False)
max row-sum discrepancy: tensor(1.1921e-07)

check agreement between state-vector algorithms:
max diff between wPwr and wLst: tensor(7.4506e-08)
max diff between wPwrN and wPwr: tensor(1.8626e-08)
max diff between wPwrN and wLstN: tensor(7.5554e-08)

check disagreement / agreement for gradients of algorithms:
max diff between grads: tensor(0.2146)
max diff between "norm" grads: tensor(5.9372e-08)

check state vector and gradient for sparse algorithm:
max diff between wSprN and wLstN: tensor(8.5808e-08, grad_fn=<MaxBackward1>)
max diff for sparse "norm" grad: tensor(8.4285e-08)
```
Best.
K. Frank