Find value in tensor

I have a tensor like [[1,2],[3,4],[0,2],[2,0]]. For a given value, I want to find, for each column, the index of the row that contains that value. For example, for 0 it should give me [2, 3], since row 2 (0-indexed) has 0 in the first column and row 3 has 0 in the second column. I want to do this with tensor operations. How do I go about this?

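You could build a boolean mask and then take argmax over the rows (dim=0). For each column this returns the index of the first row where the mask is True. The mask is cast to int first, since argmax may not accept bool tensors in every PyTorch version:

import torch
tensor = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])
mask = (tensor == 0)
row_indices = mask.int().argmax(dim=0)  # tensor([2, 3])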
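One caveat with argmax: if a column contains no match at all, its mask column is all False and argmax still returns 0, which looks the same as a real match in row 0. A minimal sketch that marks such columns with -1 instead (the helper name `first_row_with_value` and the -1 sentinel are illustrative choices, not a PyTorch API):

```python
import torch

def first_row_with_value(t: torch.Tensor, value) -> torch.Tensor:
    """For each column of a 2-D tensor, return the first row index equal to `value`, or -1 if none."""
    mask = t == value
    # argmax over rows returns the first True per column (0 if the column has no True)
    first = mask.int().argmax(dim=0)
    # columns without any match get the hypothetical sentinel -1
    found = mask.any(dim=0)
    return torch.where(found, first, torch.full_like(first, -1))

t = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])
print(first_row_with_value(t, 0))  # tensor([2, 3])
print(first_row_with_value(t, 4))  # tensor([-1, 1])
```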
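Alternatively, np.argwhere does this for a NumPy array; for a tensor, torch.nonzero returns the position of every match:

import torch
x = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])
torch.nonzero(x == 0)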

returns one [row, column] pair per match:

tensor([[2, 0],
        [3, 1]])
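Because torch.nonzero lists every match, it also handles columns where the value appears more than once. A minimal sketch that groups the matching row indices by column (the helper name `rows_per_column` is hypothetical):

```python
import torch

def rows_per_column(t: torch.Tensor, value):
    """Return, for each column of a 2-D tensor, the list of row indices where t == value."""
    idx = torch.nonzero(t == value)  # shape (num_matches, 2): [row, column] pairs
    # keep the row entry of every pair whose column entry equals c
    return [idx[idx[:, 1] == c, 0].tolist() for c in range(t.shape[1])]

t = torch.tensor([[1, 2], [3, 4], [0, 2], [2, 0]])
print(rows_per_column(t, 0))  # [[2], [3]]
print(rows_per_column(t, 2))  # [[3], [0, 2]]
```

A column with no match simply gets an empty list, so there is no ambiguity with row 0.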