# Getting diagonal elements of matrices in batch

I have a tensor `T` with dimensions `(batch_size, size_n, size_n)`

```
T =
[[[0.9527, 0.8821],
  [0.8442, 0.6147]],
 [[0.0672, 0.5737],
  [0.1963, 0.4532]],
 [[0.0992, 0.3838],
  [0.4169, 0.0925]]]
```

and want to extract the diagonal of each matrix in that batch to get

```
diag_T =
[[0.9527, 0.6147],
 [0.0672, 0.4532],
 [0.0992, 0.0925]]
```

Is there some `torch.diag()` function that also works for batches?

Maybe not the best solution, but it is vectorized:

```python
import torch

T = torch.tensor([[[0.9527, 0.8821],
                   [0.8442, 0.6147]],
                  [[0.0672, 0.5737],
                   [0.1963, 0.4532]],
                  [[0.0992, 0.3838],
                   [0.4169, 0.0925]]])

# Zero out the off-diagonal entries with a stacked identity mask, then
# read off each matrix's diagonal with a repeated-index einsum.
mask = torch.stack(tuple(torch.ones(T.size(1)).diag() for i in range(T.size(0))))
torch.einsum('ijj->ij', mask * T)
```
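(As an aside, the repeated-index `einsum` pattern already reads out the diagonal on its own, so the stacked identity mask is not actually needed. A minimal sketch:)

```python
import torch

T = torch.tensor([[[0.9527, 0.8821],
                   [0.8442, 0.6147]],
                  [[0.0672, 0.5737],
                   [0.1963, 0.4532]],
                  [[0.0992, 0.3838],
                   [0.4169, 0.0925]]])

# 'ijj->ij' indexes element (j, j) of each matrix i, i.e. the diagonal,
# with no intermediate mask tensor.
diag_T = torch.einsum('ijj->ij', T)
print(diag_T)
```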

Hi Samuel!

`torch.diagonal()` does what you want:

```python
>>> import torch
>>> torch.__version__
'1.9.0'
>>> T = torch.tensor(
...     [[[0.9527, 0.8821],
...       [0.8442, 0.6147]],
...      [[0.0672, 0.5737],
...       [0.1963, 0.4532]],
...      [[0.0992, 0.3838],
...       [0.4169, 0.0925]]]
... )
>>> torch.diagonal(T, dim1=-2, dim2=-1)
tensor([[0.9527, 0.6147],
        [0.0672, 0.4532],
        [0.0992, 0.0925]])
```

Best.

K. Frank


Why do we use `dim1=-2, dim2=-1` and not just `dim1=1, dim2=2`?

Am I missing something here?

Hi Samuel!

No particular reason – your `T` is a 3d tensor so the two versions are
equivalent. I suppose that using “negative” dimensions emphasizes
that we’re extracting the diagonals from the 2d matrices made up by
the last two dimensions of `T` (so that this version would generalize to a
hypothetical use case where `T` had multiple leading “batch” dimensions
such as `T` of shape `[batch_size, channel_size, size_n, size_n]`).

It’s really just stylistic – and not necessarily a better style.
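(To make the generalization concrete, here is a sketch with a hypothetical 4-d `T` that has two leading "batch" dimensions; the same call works unchanged because the negative dims always pick out the trailing matrix dimensions:)

```python
import torch

# Hypothetical shape [batch_size, channel_size, size_n, size_n].
T4 = torch.rand(3, 5, 2, 2)

# dim1=-2, dim2=-1 select the last two dimensions regardless of how
# many leading batch dimensions there are.
d = torch.diagonal(T4, dim1=-2, dim2=-1)
print(d.shape)  # torch.Size([3, 5, 2])
```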

Best.

K. Frank
