What might be a possible analog for https://www.tensorflow.org/api_docs/python/tf/linalg/diag_part in pytorch?
I didnot manage to find something in docs
I don't think there is an exact equivalent, but here is a workaround:
PS: this example is the same as the one in the link you provided.
import torch
x = torch.Tensor([
[[1, 0, 0, 0],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]],
[[5, 0, 0, 0],
[0, 6, 0, 0],
[0, 0, 7, 0],
[0, 0, 0, 8]]
])
# take the main diagonal of each 4x4 matrix in the batch and stack them -> shape (2, 4)
y = torch.stack(tuple(t.diag() for t in torch.unbind(x, 0)))
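The loop above only handles a single batch dimension. As a sketch, it can be generalised to any number of leading batch dimensions (as tf.linalg.diag_part allows) by flattening them first and restoring them afterwards. `diag_part_loop` is a hypothetical helper name, not a PyTorch API.

```python
import torch

def diag_part_loop(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not part of PyTorch): generalises the
    # unbind/diag/stack workaround to inputs of any rank >= 2 by
    # flattening all leading batch dims into one, then restoring them.
    batch_shape = x.shape[:-2]
    flat = x.reshape(-1, *x.shape[-2:])  # (B, rows, cols)
    diags = torch.stack([m.diag() for m in torch.unbind(flat, 0)])
    return diags.reshape(*batch_shape, diags.shape[-1])

x = torch.Tensor([
    [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
    [[5, 0, 0, 0], [0, 6, 0, 0], [0, 0, 7, 0], [0, 0, 0, 8]],
])
print(diag_part_loop(x))

# 4-D input: a (3, 2) batch of 4x5 matrices -> diagonals of length 4
print(diag_part_loop(torch.ones(3, 2, 4, 5)).shape)  # torch.Size([3, 2, 4])
```

Note that torch.stack fails on an empty list, so this sketch assumes the batch contains at least one matrix.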
I think torch.diagonal now serves the same purpose.
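A minimal sketch of a vectorised replacement, assuming you want the main diagonal of the last two dimensions batched over everything in front (which is what tf.linalg.diag_part does by default). The name `diag_part` is a hypothetical wrapper; the work is done by torch.diagonal. One thing worth knowing: torch.diagonal returns a view, so writes into the result show up in the original tensor (call `.clone()` if you need an independent copy).

```python
import torch

def diag_part(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical wrapper mirroring tf.linalg.diag_part: main diagonal of
    # the last two dims, batched over all leading dims. No Python loop.
    return torch.diagonal(x, offset=0, dim1=-2, dim2=-1)

x = torch.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)
d = diag_part(x)
print(d.shape)  # torch.Size([2, 3, 4])

# Cross-check against an explicit loop over every 4x4 matrix in the batch.
expected = torch.stack([m.diag() for m in x.reshape(-1, 4, 4)]).reshape(2, 3, 4)
assert torch.equal(d, expected)

# torch.diagonal returns a view: writing into it changes x as well.
d[0, 0, 0] = -1
print(x[0, 0, 0, 0].item())  # -1
```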
import numpy as np
import torch
a = np.array([[[1, 2, 3, 4], # Input shape: (2, 3, 4)
[5, 6, 7, 8],
[9, 8, 7, 6]],
[[5, 4, 3, 2],
[1, 2, 3, 4],
[5, 6, 7, 8]]])
a = torch.from_numpy(a)
# a.size()
torch.diagonal(a, offset=0, dim1=-2, dim2=-1)
# torch.diagonal(a, offset=0, dim1=1, dim2=2)  # same result: for a 3-D tensor, dims 1 and 2 are the last two dims
This returns:
tensor([[1, 6, 7],
[5, 2, 7]])
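The `offset` argument used above selects which diagonal is returned: 0 is the main diagonal, positive values move above it and negative values below it. As far as I know this plays roughly the role of a scalar `k` in tf.linalg.diag_part. A small sketch on the same input (built directly with torch.tensor instead of going through NumPy):

```python
import torch

a = torch.tensor([[[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6]],
                  [[5, 4, 3, 2],
                   [1, 2, 3, 4],
                   [5, 6, 7, 8]]])  # shape (2, 3, 4)

main = torch.diagonal(a, offset=0, dim1=-2, dim2=-1)    # entries (i, i)
upper = torch.diagonal(a, offset=1, dim1=-2, dim2=-1)   # entries (i, i + 1)
lower = torch.diagonal(a, offset=-1, dim1=-2, dim2=-1)  # entries (i + 1, i)

print(main.tolist())   # [[1, 6, 7], [5, 2, 7]]
print(upper.tolist())  # [[2, 7, 6], [4, 3, 8]]
print(lower.tolist())  # [[5, 8], [1, 6]]
```

Because each matrix is 3x4 rather than square, the diagonals have different lengths (3 for offsets 0 and 1, 2 for offset -1).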
More details can be found in the torch.diagonal entry of the PyTorch documentation.