Given a 3D tensor M, how do I construct a batch of diagonal matrices whose entries come from each diag(M[i])?

Suppose I have a 3D tensor

M = torch.randn(3,2,2)

I want to construct another 3D tensor A, also of shape (3, 2, 2), such that

A[i] = torch.diag(torch.diag(M[i]))  # for all i in range(len(M))
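For a single matrix, the nested call first extracts the diagonal as a 1-D tensor and then builds a diagonal matrix from it, so the net effect is to zero every off-diagonal entry. A quick check on a hand-picked 2×2 matrix (the values are illustrative, not from the post):

```python
import torch

# Illustrative 2x2 matrix (hypothetical values)
m = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])

d = torch.diag(m)  # 2-D input -> 1-D diagonal: tensor([1., 4.])
a = torch.diag(d)  # 1-D input -> 2-D matrix with d on its diagonal
# a is tensor([[1., 0.],
#              [0., 4.]])
```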

Well, my M is part of a computation graph (e.g. it is the output of a neural network), so I'm looking for an efficient way to construct A without detaching it from the graph (I still need to backpropagate through it), because code like

A = []
for i in range(M.shape[0]):
    A.append(torch.diag(torch.diag(M[i])))
A = torch.stack(A)

raises RuntimeError: one of the variables needed for gradient ...

Any ideas? Thanks.

I didn't see an answer here, so:

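You can specify the two dimensions over which to take the diagonal with torch.diagonal. For square matrices there is an even simpler option: multiply M by an identity matrix, which zeroes every off-diagonal entry and keeps A on the autograd graph: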

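assert M.ndim == 3
# Identity mask broadcasts over the batch dim; matching dtype/device avoids a CPU/GPU mismatch
A = M * torch.eye(*M.shape[-2:], dtype=M.dtype, device=M.device)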
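The torch.diagonal route mentioned above can be sketched as follows, pairing it with torch.diag_embed to rebuild the diagonal matrices. This is a minimal sketch assuming a batch of square matrices shaped like the M in the question; it also compares the result against the per-matrix loop and checks that gradients reach M.

```python
import torch

# Batch of square matrices on the autograd graph (shape taken from the question)
M = torch.randn(3, 2, 2, requires_grad=True)

# Diagonal over the last two dims of every matrix -> shape (3, 2)
diags = torch.diagonal(M, dim1=-2, dim2=-1)

# Put each row of diags back on the diagonal of a zero matrix -> shape (3, 2, 2)
A = torch.diag_embed(diags)

# Reference result: the per-matrix loop from the question
A_loop = torch.stack([torch.diag(torch.diag(M[i])) for i in range(M.shape[0])])
assert torch.equal(A, A_loop)

# Backpropagate: only the diagonal entries of M receive gradient
A.sum().backward()
# M.grad is a stack of 2x2 identity matrices
```

Unlike the identity-mask version, diag_embed does not multiply the off-diagonal entries by zero; it never reads them at all, which can matter if M might contain inf or NaN off the diagonal (0 * inf is NaN).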