Hi all, I’m currently implementing a neural net that outputs a tensor of shape (batch_size, N, N), where each NxN matrix should have all of its row sums and column sums equal to 1. Can anyone help me, please?

Hi Leonard!

If you want to *train* your network to produce such matrices, you can add to your loss function a term that penalizes violations of your desired constraints. The matrices usually won’t satisfy the constraints exactly, but if the penalty term is weighted heavily enough (and you train enough), they should come close.

You could use:

```
penalty = ((x.sum (dim = 2) - 1)**2).sum() + ((x.sum (dim = 1) - 1)**2).sum()
```
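
In a training loop you would scale this penalty and add it to your main loss. A minimal sketch (`model`, `task_loss`, `penalty_weight`, and the data are placeholders for illustration, not anything from your code):

```
# hypothetical training step
x = model (input)   # x has shape [batch_size, N, N]
penalty = ((x.sum (dim = 2) - 1)**2).sum() + ((x.sum (dim = 1) - 1)**2).sum()
loss = task_loss (x, target) + penalty_weight * penalty
loss.backward()
```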

If you want to coerce the output of your network to satisfy the constraints, you can modify the matrices after they have been produced. Simply forcing the rows and columns to sum to one is straightforward.

I’m not sure of your use case. If you want the rows and columns of your matrices to be proper row-wise and column-wise (discrete) probability distributions – that is, not only sum to one, but have the individual matrix elements lie between zero and one – that is more involved, and I don’t know of a clean, symmetrical approach (other than employing iteration).
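
(If you do go the iterative route, a well-known scheme is Sinkhorn-Knopp normalization – alternately rescaling rows and columns. A minimal sketch, assuming you exponentiate the raw network output so that all entries are positive; the iteration count of 20 is an arbitrary choice:)

```
import torch

def sinkhorn (x, n_iters = 20):
    # x: raw scores of shape [batch_size, N, N]
    x = x.exp()   # make all entries positive
    for _ in range (n_iters):
        x = x / x.sum (dim = 2, keepdim = True)   # rows sum to one
        x = x / x.sum (dim = 1, keepdim = True)   # columns sum to one
    return x   # approaches a doubly-stochastic matrix as n_iters grows
```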

The obvious straightforward method for coercing the row and column sums, and a clunky, ad hoc approach to coercing row- and column-wise probability distributions, are illustrated in the following script:

```
import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

def fix_sums (x):       # treats rows and columns symmetrically
    N = x.shape[1]      # assumes that x is of shape [batch_size, N, N]
    x = x + (1 - x.sum (dim = 2).unsqueeze (2)) / N   # row sums
    x = x + (1 - x.sum (dim = 1).unsqueeze (1)) / N   # column sums
    return x

t = torch.randn (3, 4, 4)
u = fix_sums (t)

print ('fix_sums:')
print ('check rows: ', torch.allclose (u.sum (dim = 2), torch.ones (1)))
print ('check columns: ', torch.allclose (u.sum (dim = 1), torch.ones (1)))

def fix_probs (x):      # does not treat rows and columns symmetrically
    N = x.shape[1]      # assumes that x is of shape [batch_size, N, N]
    b1 = torch.tensor (1.0 / N)   # ad hoc buffer to maintain 0 <= x <= 1
    b2 = torch.tensor (1.5 / N)   # when column-sums correction is applied
    rmin = x.min (dim = 2)[0]                            # row minima
    rlo = torch.max (torch.min (rmin, b2), b1)           # target row minima
    rmax = x.max (dim = 2)[0]                            # row maxima
    rhi = torch.min (torch.max (rmax, 1 - b2), 1 - b1)   # target row maxima
    m = (rhi - rlo) / (rmax - rmin)                      # slope of affine map
    d = (rmax * rlo - rmin * rhi) / (rhi - rlo)          # offset of affine map
    x = m.unsqueeze (2) * (x + d.unsqueeze (2))          # 0 <= x <= 1
    x = x / x.sum (dim = 2).unsqueeze (2)                # row sums
    x = x + (1 - x.sum (dim = 1).unsqueeze (1)) / N      # column sums
    return x

v = fix_probs (t)

print ('fix_probs:')
print ('check >= 0: ', torch.all (v >= 0))
print ('check <= 1: ', torch.all (v <= 1))
print ('check rows: ', torch.allclose (v.sum (dim = 2), torch.ones (1)))
print ('check columns: ', torch.allclose (v.sum (dim = 1), torch.ones (1)))
```

Here is its output:

```
1.10.2
fix_sums:
check rows: True
check columns: True
fix_probs:
check >= 0: tensor(True)
check <= 1: tensor(True)
check rows: True
check columns: True
```

Best.

K. Frank

Oh, thank you so much @KFrank, this helps me a lot. By the way, is this algorithm for deriving a doubly stochastic matrix well known, or did you just come up with it on the spot?

I’ve tried the Sinkhorn-Knopp algorithm, but it requires an extra memory footprint: GitHub - lucidrains/sinkhorn-transformer: Sinkhorn Transformer - Practical implementation of Sparse Sinkhorn Attention

However, I got this assertion error in training when using your method @KFrank:

`assert torch.all (x >= 0), f"{x}"`

It seems x >= 0 is not always satisfied.

Hi Leonard!

Yes, the algorithm I posted is ad hoc.

```
b1 = torch.tensor (1.0 / N) # ad hoc buffer to maintain 0 <= x <= 1
b2 = torch.tensor (1.5 / N) # when column-sums-correction is applied
```

`b1` and `b2` are “fix-up” parameters that define how much to shrink the rows inside of `[0.0, 1.0]` so that they don’t go back outside of that range when the columns are subsequently modified to sum to one. They may well need to depend on `N` (the size of `x`) and possibly on the range of the elements in `x`.

I expect that if the `assert` fails, you could, as a fix-up, simply apply the `fix_probs()` function a second (or third, or fourth …) time. But that certainly would be clunky.
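
Concretely, that fix-up loop might look like the following sketch (`max_passes` is an arbitrary cap chosen for illustration; `fix_probs()` is the function from my earlier script):

```
def fix_probs_repeated (x, max_passes = 10):
    # reapply fix_probs() until all elements lie in [0, 1] (or we give up)
    for _ in range (max_passes):
        x = fix_probs (x)
        if torch.all (x >= 0) and torch.all (x <= 1):
            break
    return x
```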

Best.

K. Frank

Hi Leonard!

I don’t know of such an algorithm – what I posted is ad hoc.

Part of the problem is that the goal of the algorithm is not well defined. Consider the one-dimensional case of producing a “singly-stochastic vector:”

Let’s say you start with a vector of arbitrary values and want to map it to a singly-stochastic vector. You could use the following piece of the `fix_probs()` function (without the `b1` / `b2` “buffer”):

```
x = m.unsqueeze (2) * (x + d.unsqueeze (2)) # 0 <= x <= 1
x = x / x.sum (dim = 2).unsqueeze (2) # row sums
```

This has the advantage that if your vector already represents a valid (discrete) probability distribution, it won’t be changed.
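
Made concrete for the one-dimensional case, that might look like the following sketch (degenerate inputs – constant vectors, or vectors lying entirely above one or below zero – are not handled):

```
import torch

def to_prob_vector (v):
    # map [v.min(), v.max()] affinely onto that range clipped to [0, 1],
    # then normalize; a vector that is already a valid probability
    # distribution passes through unchanged
    vmin, vmax = v.min(), v.max()
    lo = vmin.clamp (min = 0.0)
    hi = vmax.clamp (max = 1.0)
    m = (hi - lo) / (vmax - vmin)
    d = (vmax * lo - vmin * hi) / (hi - lo)
    v = m * (v + d)      # now 0 <= v <= 1
    return v / v.sum()   # now sums to one
```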

But, instead, you could apply `softmax()` to your vector – the standard way to convert a vector of *raw-score logits* to probabilities. This has the disadvantage that applying `softmax()` to a valid set of probabilities does *not* leave them unchanged. Which is right? That really depends on what the meaning of the original input vector is.
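
A small demonstration of that disadvantage (the numbers are arbitrary):

```
import torch

p = torch.tensor ([0.1, 0.2, 0.7])   # already a valid probability vector
q = torch.softmax (p, dim = 0)       # still sums to one, but differs from p
print (torch.allclose (p, q))        # False
```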

To illustrate this issue with a silly exaggeration, you can convert your input matrix to a doubly-stochastic matrix by replacing all of its elements with `1 / N`. Of course, the result is now independent of the input matrix. But if you reject this `1 / N` proposal, you have to provide some concrete criterion for how the output matrix is supposed to depend on the input. Without such a criterion, the `1 / N` proposal would be a legitimate solution.
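
In code, that degenerate “solution” is a one-liner (the shape here is arbitrary):

```
import torch

x = torch.randn (3, 4, 4)                    # arbitrary network output
ds = torch.full_like (x, 1.0 / x.shape[1])   # doubly stochastic, but ignores x
```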

Best.

K. Frank