I have a tensor like [[1,2],[3,4],[0,2],[2,0]]. For each column, I want to find the index of the row that contains a particular value. For example, for 0 it should give me [2, 3], since row 2 (0-indexed) has 0 in the first column and row 3 has 0 in the second column. How do I do this with a tensor?
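You could use a boolean mask and then take argmax along the rows (dim=0), which gives the first matching row in each column:
mask = (tensor == 0)
row_indices = mask.int().argmax(dim=0)  # tensor([2, 3]) for the example above
Note that a column with no match also yields 0, so check mask.any(dim=0) if that can happen.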
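A self-contained sketch of the masking approach, wrapped in a hypothetical helper `first_row_per_column` (not a PyTorch function). It relies on `torch.argmax` returning the first maximal index on ties, and marks columns with no match as -1; that sentinel is my own choice, not a PyTorch convention.

```python
import torch

def first_row_per_column(t: torch.Tensor, value) -> torch.Tensor:
    """Hypothetical helper: for each column, the first row index where t == value (-1 if none)."""
    mask = t == value  # bool mask with the same shape as t
    # Cast to int so we don't rely on argmax supporting bool tensors.
    # On ties, torch.argmax returns the index of the first maximal value.
    idx = mask.int().argmax(dim=0)
    # argmax returns 0 for an all-False column, so mark those columns explicitly.
    idx[~mask.any(dim=0)] = -1
    return idx

x = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])
print(first_row_per_column(x, 0))  # tensor([2, 3])
print(first_row_per_column(x, 2))  # tensor([3, 0])
print(first_row_per_column(x, 7))  # tensor([-1, -1])
```

This only returns one row per column; if a value can occur several times in a column, the later matches are dropped.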
Method 1:
import torch
x = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])
torch.nonzero(x == 0)  # same result as np.transpose(np.argwhere(x == 0)), without mixing in NumPy
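returns
tensor([[2, 0],
        [3, 1]])
Each row is a (row, column) pair, so the row indices you are after are in the first column of the result: [2, 3].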
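A small sketch of how to pull the row indices out of the `torch.nonzero` result. The pairs come back in row-major order (by row, then column), so `matches[:, 0]` is not grouped per column in general; the hypothetical helper `rows_per_column` below groups them when a value can appear more than once in a column.

```python
import torch

x = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])

# Each row of `matches` is a (row, column) pair where x == 0, in row-major order.
matches = torch.nonzero(x == 0)
rows, cols = matches[:, 0], matches[:, 1]
print(rows)  # tensor([2, 3])
print(cols)  # tensor([0, 1])

def rows_per_column(t: torch.Tensor, value) -> list:
    """Hypothetical helper: list of matching row indices for each column of t."""
    m = torch.nonzero(t == value)
    return [m[m[:, 1] == c, 0].tolist() for c in range(t.shape[1])]

print(rows_per_column(x, 2))  # [[3], [0, 2]]
```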