Suppose I have a 3D tensor

`M = torch.randn(3, 2, 2)`

and I want to construct another 3D tensor `A` with the same shape `(3, 2, 2)` such that

```
A[i] = torch.diag(torch.diag(M[i]))  # for i in range(len(M))
```

However, my `M` here is part of a computation graph (e.g. it is the output of a neural network), so I'm looking for an efficient way of constructing `A` without detaching it from the graph (so that I can still backpropagate). Code like

```
A = []
for i in range(M.shape[0]):
    A.append(torch.diag(torch.diag(M[i])))
A = torch.stack(A)
```
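For what it's worth, here is a loop-free sketch of the kind of thing I'm after, using `torch.diagonal` and `torch.diag_embed` (assuming those batched ops fit the use case); both are differentiable, so gradients should flow back to `M`:

```python
import torch

# Stand-in for a network output: a batch of matrices that requires grad.
M = torch.randn(3, 2, 2, requires_grad=True)

# Extract each matrix's diagonal (shape (3, 2)), then embed it back into
# a batch of diagonal matrices (shape (3, 2, 2)). No Python loop, no
# in-place assignment, so the computation graph stays intact.
A = torch.diag_embed(torch.diagonal(M, dim1=-2, dim2=-1))

# Gradients flow back to M.
A.sum().backward()
print(A.shape)       # torch.Size([3, 2, 2])
print(M.grad.shape)  # torch.Size([3, 2, 2])
```
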

raises `RuntimeError: one of the variables needed for gradient`.

Any ideas? Thanks!