Getting diagonal elements of matrices in batch

I have a tensor T of shape (batch_size, size_n, size_n):

T = 
    [[[0.9527, 0.8821],
      [0.8442, 0.6147]],
     [[0.0672, 0.5737],
      [0.1963, 0.4532]],
     [[0.0992, 0.3838],
      [0.4169, 0.0925]]]

and want to extract the diagonal of each matrix in that batch to get

diag_T = 
    [[0.9527, 0.6147],
     [0.0672, 0.4532],
     [0.0992, 0.0925]]

Is there some torch.diag() function that also works for batches?

Maybe not the best solution, but it is vectorized:

import torch

T = [[[0.9527, 0.8821],
      [0.8442, 0.6147]],
     [[0.0672, 0.5737],
      [0.1963, 0.4532]],
     [[0.0992, 0.3838],
      [0.4169, 0.0925]]]

T = torch.tensor(T)

# The repeated index j in 'ijj' already selects the diagonal of each matrix,
# so multiplying by a stack of identity masks first is not needed.
diag_T = torch.einsum('ijj->ij', T)

Hi Samuel!

torch.diagonal() does what you want:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> T = torch.tensor(
...     [[[0.9527, 0.8821],
...       [0.8442, 0.6147]],
...      [[0.0672, 0.5737],
...       [0.1963, 0.4532]],
...      [[0.0992, 0.3838],
...       [0.4169, 0.0925]]]
... )
>>> torch.diagonal(T, dim1=-2, dim2=-1)
tensor([[0.9527, 0.6147],
        [0.0672, 0.4532],
        [0.0992, 0.0925]])
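
As a small sanity check (a sketch, not part of the session above), the result can be compared against explicit indexing. The names idx, by_index and shares_storage are only illustrative. It also shows that torch.diagonal() returns a view of T rather than a copy, so you may want to call .clone() on the result before modifying it in place.

```python
import torch

T = torch.tensor(
    [[[0.9527, 0.8821],
      [0.8442, 0.6147]],
     [[0.0672, 0.5737],
      [0.1963, 0.4532]],
     [[0.0992, 0.3838],
      [0.4169, 0.0925]]]
)

# Diagonal of each matrix in the batch: shape (batch_size, size_n).
diag_T = torch.diagonal(T, dim1=-2, dim2=-1)

# Cross-check with advanced indexing: pick element [b, k, k] for every k.
idx = torch.arange(T.size(-1))
by_index = T[:, idx, idx]

print(diag_T.shape)                    # torch.Size([3, 2])
print(torch.equal(diag_T, by_index))   # True

# The diagonal is a view: it starts at the same memory address as T.
shares_storage = diag_T.data_ptr() == T.data_ptr()
print(shares_storage)                  # True
```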

Best.

K. Frank


Why do we use dim1=-2, dim2=-1 and not just dim1=1, dim2=2?

Am I missing something here?

Hi Samuel!

No particular reason – your T is a 3d tensor so the two versions are
equivalent. I suppose that using “negative” dimensions emphasizes
that we’re extracting the diagonals from the 2d matrices made up by
the last two dimensions of T (so that this version would generalize to a
hypothetical use case where T had multiple leading “batch” dimensions
such as T of shape [batch_size, channel_size, size_n, size_n]).

It’s really just stylistic – and not necessarily a better style.
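
To make the difference concrete, here is a minimal sketch with a 4d tensor. The sizes and the names T4, diag_neg and diag_pos are made up for illustration. With two leading "batch" dimensions, the negative-index version still picks the last two dimensions, while dim1=1, dim2=2 would pair the channel dimension with the first matrix dimension.

```python
import torch

# Hypothetical sizes, chosen only so the shapes are easy to tell apart.
batch_size, channel_size, size_n = 2, 3, 4
T4 = torch.arange(
    batch_size * channel_size * size_n * size_n, dtype=torch.float32
).reshape(batch_size, channel_size, size_n, size_n)

# Negative dims always refer to the trailing (size_n, size_n) matrices.
diag_neg = torch.diagonal(T4, dim1=-2, dim2=-1)
print(diag_neg.shape)   # torch.Size([2, 3, 4])

# Positive dims 1 and 2 now mean (channel_size, size_n), which is a
# different, probably unintended, diagonal of length min(3, 4) = 3.
diag_pos = torch.diagonal(T4, dim1=1, dim2=2)
print(diag_pos.shape)   # torch.Size([2, 4, 3])
```

For the original 3d T the two spellings give identical results; they only diverge once extra leading dimensions are added.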

Best.

K. Frank
