Suppose I have a 3D tensor
M = torch.randn(3, 2, 2)
and I want to construct another 3D tensor A, also of shape (3, 2, 2), such that
A[i] = torch.diag(torch.diag(M[i])) # for all i in range(len(M))
The catch is that M is part of a computation graph (e.g. it is the output of a neural network), so I'm looking for an efficient way of constructing A without detaching it from the graph (it must still backpropagate). Code like
A = []
for i in range(M.shape[0]):
    A.append(torch.diag(torch.diag(M[i])))
A = torch.stack(A)
raises RuntimeError: one of the variables needed for gradient ...
Any ideas? Thanks!
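For completeness, here is the kind of batched construction I was hoping exists. This is just a sketch, assuming torch.diagonal and torch.diag_embed do what their docs say (extract batched diagonals and re-embed them as diagonal matrices) and that both are differentiable:

```python
import torch

M = torch.randn(3, 2, 2, requires_grad=True)

# Extract the diagonal of every M[i] at once -> shape (3, 2),
# then re-embed each row as a diagonal matrix -> shape (3, 2, 2).
# Both ops are autograd-aware, so A stays on the computation graph.
diags = torch.diagonal(M, dim1=-2, dim2=-1)
A = torch.diag_embed(diags)

# Gradients still flow back to M:
A.sum().backward()
```

After backward(), M.grad should be 1 on each M[i]'s diagonal and 0 elsewhere, since the off-diagonal entries of M never influence A.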