Find indices of a tensor satisfying a condition

Does torch have a function that helps in finding the indices satisfying a condition?

For instance,

F = torch.randn(10)
b = torch.index(F <= 0)
print(b)

Thanks!

Hi,

torch follows the same convention as numpy for finding the values or indices of a tensor that satisfy a condition: torch.where and torch.nonzero.

What you can do is apply your condition to get a boolean mask of the elements that match, then find their indices using torch.nonzero().

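import torch

a = torch.randn(10)
b = a <= 0             # boolean mask, True where the condition holds
indices = b.nonzero()  # indices of the True entries, shape [num_matches, 1]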
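
As a small follow-up sketch: for a 1-D input, .nonzero() returns a column of shape [num_matches, 1], which is not always what you want for indexing. Passing as_tuple=True, or calling torch.where with only the condition, should give a flat index tensor instead. The fixed values below are an assumption made for reproducibility; they replace the torch.randn call above.

```python
import torch

# Fixed values (illustrative assumption) so the result is reproducible.
a = torch.tensor([0.5, -1.0, 0.0, 2.0, -3.0])
mask = a <= 0

idx_2d = mask.nonzero()                    # shape [3, 1], one row per match
idx_flat = mask.nonzero(as_tuple=True)[0]  # shape [3], flat indices
idx_where = torch.where(mask)[0]           # same result via torch.where

print(idx_2d.shape)
print(idx_flat.tolist())  # [1, 2, 4]
print(a[idx_flat])        # the values that satisfied the condition
```

The flat form can be used directly to index back into the original tensor.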

Best

Thanks, @Nikronic. This really solves it.

How can you do it for a 2-D MxM tensor?

The same way, i.e. .nonzero()
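
To illustrate, here is a minimal sketch for the 2-D case. The 3x3 tensor m and its values are made up for the example; each row of the result is one [row, col] pair.

```python
import torch

# Hypothetical 3x3 tensor with a few non-positive entries.
m = torch.tensor([[ 1.0, -2.0,  3.0],
                  [-4.0,  5.0,  0.0],
                  [ 7.0,  8.0, -9.0]])

idx = (m <= 0).nonzero()  # shape [num_matches, 2]
print(idx.tolist())       # [[0, 1], [1, 0], [1, 2], [2, 2]]

# as_tuple=True gives separate row and column index tensors,
# which can be used to index back into m.
rows, cols = (m <= 0).nonzero(as_tuple=True)
print(m[rows, cols])
```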

Hi @harpone
Is it possible to find the index of an N-D tensor? For example, can I find the index of torch.tensor([1, 2]) in torch.tensor([[1,2], [3, 4], [5, 6]])?

Yeah, the latter tensor here is a 2-D tensor, and .nonzero() works fine with that. In general, if you have an N-D tensor, you will get an index tensor of shape [num_matches, N].
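
For the specific case of locating the row [1, 2] inside the 2-D tensor, one possible sketch is to compare every row against the target, reduce over the last dimension, and then call .nonzero() on the resulting mask. The names haystack and needle are made up for illustration, and this is only one way to do it.

```python
import torch

haystack = torch.tensor([[1, 2], [3, 4], [5, 6]])
needle = torch.tensor([1, 2])

# Broadcasting compares needle against every row; a row matches
# only if all of its elements are equal to the target.
row_mask = (haystack == needle).all(dim=1)

row_idx = row_mask.nonzero(as_tuple=True)[0]
print(row_idx.tolist())  # [0]
```

If the target row is not present, row_idx is simply an empty tensor.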