Given an array of covariance matrices, how can I convert each one to a correlation matrix efficiently without removing them from the computation graph?

Suppose I have an array of matrices M with M.shape = (N, D, D), where D is the dimension of each covariance matrix and N is the number of covariance matrices.

I have a function that converts a single covariance matrix to a correlation matrix, defined as:

import torch

def convert_to_correlation(m):
    # inv = diag(1 / sqrt(m_ii)): inverse of the diagonal standard-deviation matrix
    inv = torch.inverse(torch.diag(torch.diag(m)).sqrt())
    r = inv @ m @ inv
    return r
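Because inv is diagonal with entries 1/sqrt(m_ii), the product inv @ m @ inv simply divides each entry m_ij by sqrt(m_ii * m_jj). The sketch below checks this on a random covariance matrix; the helper name correlation_elementwise is hypothetical (not from the thread), and it assumes every diagonal entry is strictly positive. The elementwise form skips the explicit matrix inverse.

```python
import torch

def convert_to_correlation(m):
    # Original approach: inv = diag(1 / sqrt(m_ii)) via an explicit inverse.
    inv = torch.inverse(torch.diag(torch.diag(m)).sqrt())
    return inv @ m @ inv

def correlation_elementwise(m):
    # Hypothetical helper: r_ij = m_ij / (std_i * std_j), no inverse needed.
    std = torch.diag(m).sqrt()                  # shape (D,)
    return m / (std[:, None] * std[None, :])    # outer product of std devs

# Random symmetric positive-definite covariance matrix.
torch.manual_seed(0)
a = torch.randn(4, 4, dtype=torch.float64)
cov = a @ a.T + 4 * torch.eye(4, dtype=torch.float64)

r1 = convert_to_correlation(cov)
r2 = correlation_elementwise(cov)
print(torch.allclose(r1, r2))  # True
```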

When I have a batch of such covariance matrices, I process them like this:

def convert_all_to_correlation(M):
    res = []
    for i in range(len(M)):
        m = convert_to_correlation(M[i])
        res.append(m)
    res = torch.stack(res)
    return res

There are two problems with the above approach:

(1) The array can have a large first dimension, so the Python loop is very slow.
(2) M is the output of a neural network (and hence part of the computation graph); I don't want convert_all_to_correlation to remove it from the graph.

Any help is appreciated! Thanks

Hi,

I’m a bit confused by your convert_to_correlation. It does not use m.

For 1: all the operations you use support batching (.diagonal, .inverse and matmul), so you can write a convert_to_correlation that works on the whole batch at once.
For 2: derivatives are implemented for matrix inversion and for all the other functions you use, so the result will be differentiable. (You can check that res.requires_grad is True.)
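A minimal batched sketch along the lines of the reply, assuming M has shape (N, D, D) with strictly positive diagonals. It reads every diagonal at once with M.diagonal(dim1=-2, dim2=-1) and builds a batch of diagonal matrices with torch.diag_embed; as a variant, it uses 1 / std instead of calling torch.inverse. The name convert_all_to_correlation_batched and the random stand-in for the network output are hypothetical. It compares the result with the original loop and confirms that gradients still reach the input.

```python
import torch

def convert_to_correlation(m):
    # Original single-matrix version from the question.
    inv = torch.inverse(torch.diag(torch.diag(m)).sqrt())
    return inv @ m @ inv

def convert_all_to_correlation_batched(M):
    # Hypothetical batched version. M: (N, D, D) -> diagonals: (N, D).
    std = M.diagonal(dim1=-2, dim2=-1).sqrt()
    # Batch of diagonal matrices diag(1 / std): (N, D, D).
    # torch.inverse(torch.diag_embed(std)) also works on batches,
    # but the reciprocal of a diagonal is cheaper.
    inv = torch.diag_embed(1.0 / std)
    # Batched matmul over the leading N dimension.
    return inv @ M @ inv

torch.manual_seed(0)
N, D = 8, 3
# Stand-in for a network output: A requires grad, so M is in the graph.
A = torch.randn(N, D, D, dtype=torch.float64, requires_grad=True)
M = A @ A.transpose(-2, -1) + torch.eye(D, dtype=torch.float64)

res = convert_all_to_correlation_batched(M)
loop = torch.stack([convert_to_correlation(M[i]) for i in range(N)])

print(torch.allclose(res, loop))  # True
print(res.requires_grad)          # True

# Gradients flow back to A through the batched conversion.
res.sum().backward()
print(A.grad.shape)               # torch.Size([8, 3, 3])
```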

Sorry, that was a typo; I've fixed it.

I've also checked torch.diag(); it seems to work only on 2D tensors?

Yes, .diag() is an older API that only handles a single 1D or 2D tensor (no batch dimension). That is why I mentioned .diagonal() above: it is the newer API that supports batching (and taking the diagonal over any pair of dimensions, which is super cool).
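For unbatched inputs, torch.diag plays two roles: on a 2D tensor it extracts the diagonal, and on a 1D tensor it builds a diagonal matrix. A short sketch of the batched counterparts, torch.diagonal and torch.diag_embed, checked against per-matrix torch.diag calls on a random batch (the batch itself is just illustrative data):

```python
import torch

torch.manual_seed(0)
batch = torch.randn(5, 3, 3)  # N=5 matrices of size 3x3

# Role 1: matrix -> diagonal vector.
# torch.diagonal handles the batch when told which two dims hold each matrix.
diags = torch.diagonal(batch, dim1=-2, dim2=-1)          # shape (5, 3)
assert torch.equal(diags, torch.stack([torch.diag(m) for m in batch]))

# Role 2: vector -> diagonal matrix.
# torch.diag_embed builds one diagonal matrix per row of its input.
mats = torch.diag_embed(diags)                            # shape (5, 3, 3)
assert torch.equal(mats, torch.stack([torch.diag(v) for v in diags]))

# torch.diag itself rejects a 3D input.
try:
    torch.diag(batch)
except RuntimeError as err:
    print("torch.diag on 3D input failed:", err)
```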

Thanks for the advice. However, I'm confused about how .diagonal() behaves here.

For example:

>>> x
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
>>> torch.diagonal(x)
tensor([[1, 7],
        [2, 8]])

whereas I was expecting [[1, 4], [5, 8]]. Am I missing something?

You can check the documentation for the function.
If you want [[1, 4], [5, 8]], you need the diagonal over dimensions 1 and 2 (dim1=1, dim2=2) instead of the default dimensions 0 and 1.
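Applied to the tensor from the question, a quick sketch of the default call versus the dim1/dim2 arguments of torch.diagonal:

```python
import torch

x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])

# Default dim1=0, dim2=1: takes x[i, i, :], with the diagonal as the last dim.
print(torch.diagonal(x))
# tensor([[1, 7],
#         [2, 8]])

# Diagonal of each 2x2 matrix in the batch: dims 1 and 2.
print(torch.diagonal(x, dim1=1, dim2=2))
# tensor([[1, 4],
#         [5, 8]])

# Negative dims give the same result and also work with extra leading batch dims.
print(torch.diagonal(x, dim1=-2, dim2=-1))
```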

Perfect! Thanks for the help!