# Given an array of covariance matrices, how can I convert each of them to a correlation matrix efficiently without removing them from the computation graph?

Suppose I have an array of matrices with `M.shape = (N, D, D)`, where `D` is the dimension of each covariance matrix and `N` is the number of covariance matrices.

I have a function that converts a single covariance matrix to a correlation matrix, defined as:

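```python
import torch

def convert_to_correlation(m):
    # Inverse of the diagonal matrix of standard deviations
    inv = torch.inverse(torch.diag(torch.diag(m)).sqrt())
    # Scale rows and columns by 1/std: r[i, j] = m[i, j] / sqrt(m[i, i] * m[j, j])
    r = inv @ m @ inv
    return r
```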
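For reference, the function above computes the usual covariance-to-correlation formula. Writing $S$ for the diagonal matrix of standard deviations (notation introduced here, not used in the original post):

$$
S = \operatorname{diag}\!\left(\sqrt{m_{11}}, \dots, \sqrt{m_{DD}}\right), \qquad
R = S^{-1}\, m\, S^{-1}, \qquad
R_{ij} = \frac{m_{ij}}{\sqrt{m_{ii}\, m_{jj}}}
$$

So `inv` in the code is $S^{-1}$, and every diagonal entry of $R$ equals 1.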

When I have a batch of such covariance matrices, I loop over them:

```python
def convert_all_to_correlation(M):
    res = []
    # Convert one matrix at a time, then stack back to shape (N, D, D)
    for i in range(len(M)):
        m = convert_to_correlation(M[i])
        res.append(m)
    res = torch.stack(res)
    return res
```

There are two problems with the above approach:

(1) The first dimension `N` can be large, so the Python loop is very slow.

(2) My `M` is the output of a neural network (and hence part of the computation graph); I don't want `convert_all_to_correlation` to remove it from the graph.

Any help is appreciated! Thanks

Hi,

I’m a bit confused by your `convert_to_correlation`. It does not use `m`.

For 1: all the operations you use support batching (`.diagonal()`, `.inverse()` and matmul), so you can write a `convert_to_correlation` that works on the whole batch at once.
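A minimal sketch of such a batched version, keeping the original inverse-and-matmul structure. It assumes `M` has shape `(N, D, D)` with positive diagonals; `torch.diag_embed` (used to rebuild the diagonal matrices) is my addition and is not mentioned in the thread.

```python
import torch

def convert_all_to_correlation(M: torch.Tensor) -> torch.Tensor:
    """Batched covariance -> correlation for M of shape (N, D, D)."""
    # Variances of every matrix in the batch: shape (N, D)
    var = M.diagonal(dim1=-2, dim2=-1)
    # Batched diagonal matrices of standard deviations: shape (N, D, D)
    std_mat = torch.diag_embed(var.sqrt())
    # torch.inverse works over leading batch dimensions
    inv = torch.inverse(std_mat)
    # Batched matmul: the same inv @ m @ inv, for every matrix at once
    return inv @ M @ inv
```

No Python loop runs over `N`; each step is a single batched tensor op, so the work is dispatched once rather than `N` times.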
For 2: autograd implements derivatives for inversion and all the other functions you use, so the result will be differentiable (you can check that `res.requires_grad` is `True`).
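A sketch of an alternative I'm swapping in (not from the thread): because the matrix being inverted is diagonal, its inverse just divides by standard deviations, so the same result can be computed with broadcasting and no `torch.inverse` call. The snippet also checks that gradients reach a tensor `A` standing in for network parameters. `convert_all_to_correlation_div` is a hypothetical name.

```python
import torch

def convert_all_to_correlation_div(M: torch.Tensor) -> torch.Tensor:
    """Hypothetical alternative: R[..., i, j] = M[..., i, j] / (std_i * std_j)."""
    # Standard deviations of each matrix in the batch: shape (..., D)
    std = M.diagonal(dim1=-2, dim2=-1).sqrt()
    # Outer product of std with itself via broadcasting: shape (..., D, D)
    denom = std.unsqueeze(-1) * std.unsqueeze(-2)
    return M / denom

# A stands in for network parameters; M is built from it, so M is in the graph
A = torch.randn(4, 3, 3, dtype=torch.float64, requires_grad=True)
M = A @ A.transpose(-1, -2) + 3 * torch.eye(3, dtype=torch.float64)
R = convert_all_to_correlation_div(M)
print(R.requires_grad)  # True
R.sum().backward()
print(A.grad.shape)  # torch.Size([4, 3, 3])
```

Every op here (`diagonal`, `sqrt`, broadcasting multiply and divide) is differentiable, so nothing is detached from the graph.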

Sorry, there was a typo there; I've fixed it.

I've also checked the `torch.diag()` function; it seems to only work on 2D tensors?

Yes, `.diag()` is an old API that only supports 1D and 2D tensors (no batch dimension). That is why I mentioned `.diagonal()` above: it is the newer API that supports batching (and taking the diagonal over arbitrary dimensions, which is super cool).
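A small sketch of the difference, using a made-up 2x2 matrix. `torch.diag_embed`, which rebuilds batched diagonal matrices from batched vectors (the other job `torch.diag` does in the original code), is my addition and not mentioned in the thread.

```python
import torch

m = torch.tensor([[4.0, 2.0],
                  [2.0, 9.0]])
M = torch.stack([m, 2 * m])  # a batch of shape (2, 2, 2)

# On a single 2D matrix, torch.diag and .diagonal() agree: values [4, 9]
d_single_old = torch.diag(m)
d_single_new = m.diagonal()

# On a 3D batch, take one diagonal per matrix over the last two dims:
# values [[4, 9], [8, 18]], shape (2, 2)
d_batch = M.diagonal(dim1=-2, dim2=-1)

# torch.diag_embed is the batched counterpart of torch.diag(vector):
# each length-D row becomes a D x D diagonal matrix, giving shape (2, 2, 2)
S = torch.diag_embed(d_batch.sqrt())
```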

Thanks for the advice. However, I'm confused about how `.diagonal()` works here.

For example:

```python
>>> x
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])

>>> torch.diagonal(x)
tensor([[1, 7],
        [2, 8]])
```

whereas I'm expecting `[[1, 4], [5, 8]]`. Am I missing something?

You can check the docs for the function.
If you want `[1, 4]` and `[5, 8]`, you want the diagonal over dimensions 1 and 2 instead of the default 0 and 1, i.e. `torch.diagonal(x, dim1=1, dim2=2)`.
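A short sketch on the same `x` from the question, showing why the default gives `[[1, 7], [2, 8]]` and how choosing `dim1`/`dim2` gives one diagonal per matrix:

```python
import torch

x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]]])

# Default dim1=0, dim2=1: the diagonal runs across the batch dimension,
# so out[k, i] == x[i, i, k], giving values [[1, 7], [2, 8]]
default = torch.diagonal(x)

# Diagonal over dims 1 and 2: one diagonal per matrix, values [[1, 4], [5, 8]]
per_matrix = torch.diagonal(x, dim1=1, dim2=2)

# Negative indices give the same result and work for any number of
# leading batch dimensions
per_matrix_neg = x.diagonal(dim1=-2, dim2=-1)
```

Using `dim1=-2, dim2=-1` is a common choice when the batch shape may vary, since it always targets the last two (matrix) dimensions.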

Perfect! Thanks for the help!