Find indices of a tensor satisfying a condition

Does torch have a function that helps in finding the indices satisfying a condition?

For instance,

F = torch.randn(10)
b = torch.index(F <= 0)
print(b)

Thanks!

Hi,

torch follows the same convention as numpy for finding the values or indices of a tensor that satisfy a condition: torch.where and torch.nonzero.

What you can do is apply your condition to get a boolean mask that is True wherever the condition holds, then find the indices of the True entries with torch.nonzero().

import torch

a = torch.randn(10)
b = a <= 0  # boolean mask: True where the condition holds
indices = b.nonzero()  # shape (num_matches, 1) for a 1-D input
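If you would rather have a flat 1-D tensor of indices, here is a minimal sketch of the two functions mentioned above. The seed and the variable names are only for illustration; torch.where called with just a condition should give the same result as nonzero(as_tuple=True).

```python
import torch

torch.manual_seed(0)  # illustrative seed, only so the run is repeatable
a = torch.randn(10)
mask = a <= 0  # boolean mask, True where the condition holds

# nonzero() returns a 2-D tensor of shape (num_matches, mask.dim())
idx_2d = mask.nonzero()

# as_tuple=True returns one 1-D index tensor per dimension instead
(idx_1d,) = mask.nonzero(as_tuple=True)

# torch.where with only a condition behaves like nonzero(as_tuple=True)
(idx_where,) = torch.where(mask)

# the index tensor can be used to pull out the matching values
values = a[idx_1d]
```

The tuple form is handy for indexing, since a[idx_1d] (or simply a[mask]) gives back the values that satisfy the condition.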

Best


Thanks, @Nikronic. This really solves it.