Finding the first occurrence of an element in a 2D tensor

How do I find only the first occurrence of an element in each row of a 2D tensor?

For instance, if the input is:

t = torch.Tensor([[10, 1, 2, 3, 4, -10, -10, -10],
                  [10, 4, 2, 3, -10, -10, -10, -10],
                  [10, 1, 2, 3, 4, 5, 6, -10]])

I want to find the index of the first occurrence of value = -10 in every row of the tensor, so the output should be:

tensor([5, 4, 7])

How would I go about this without using for loops and using torch-native operations?

I’m looking for something similar. Did you find a solution to this?

Hi Sami (and Vivek)!

((t == -10).cumsum(1).cumsum(1) == 1).argsort(1)[:, -1]
should work.
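
To unpack the one-liner, here is a minimal sketch that splits it into steps on the tensor from the question. The names mask, first_hit, first_idx, alt_idx, has_match and safe_idx are only illustrative. The .int() casts are an added precaution in case sorting a boolean tensor is not supported in your PyTorch version, and the argmax variant and the -1 sentinel are assumptions of mine, not part of the original expression.

```python
import torch

t = torch.tensor([[10, 1, 2, 3, 4, -10, -10, -10],
                  [10, 4, 2, 3, -10, -10, -10, -10],
                  [10, 1, 2, 3, 4, 5, 6, -10]])

# True wherever the target value occurs.
mask = t == -10

# The first cumsum counts matches seen so far in each row (0, ..., 1, 2, 3).
# The second cumsum of those counts equals 1 only at the first match.
first_hit = mask.cumsum(1).cumsum(1) == 1

# Ascending argsort puts the single 1 last, so the last index is its column.
first_idx = first_hit.int().argsort(1)[:, -1]
print(first_idx)  # tensor([5, 4, 7])

# Possible alternative: argmax over the one-hot row gives the same column.
alt_idx = first_hit.int().argmax(1)

# Rows without the value have no meaningful index, so flag them with -1
# (the sentinel value is an arbitrary choice for this sketch).
has_match = mask.any(1)
safe_idx = torch.where(has_match, first_idx, torch.full_like(first_idx, -1))
```

Note that for a row that does not contain -10 at all, first_hit is all False, so both the argsort and argmax versions still return some column index. That is why the sketch checks mask.any(1) separately.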

Best.

K. Frank
