Does torch have a function that helps in finding the indices satisfying a condition?
For instance,
F = torch.randn(10)
b = torch.index(F <= 0)
print(b)
Thanks!
Hi,
PyTorch follows the same convention as NumPy for finding the values or indices of a tensor that satisfy a specific condition: torch.where and torch.nonzero.
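As a small sketch (the tensor values are made up for illustration), calling torch.where with only a condition returns the matching indices as a tuple with one index tensor per dimension. This is the same as torch.nonzero(cond, as_tuple=True). The three-argument form instead selects values elementwise:

```python
import torch

# Illustrative values, chosen so the result is easy to check by eye.
F = torch.tensor([0.5, -1.2, 3.0, -0.1, 0.0])

# One-argument form: returns a tuple of index tensors, one per dimension.
# For a 1-D tensor the tuple has a single element.
(idx,) = torch.where(F <= 0)
print(idx.tolist())  # [1, 3, 4]

# Three-argument form: picks from the 2nd or 3rd argument elementwise,
# so it returns values rather than indices.
clipped = torch.where(F <= 0, torch.zeros_like(F), F)
print(clipped.tolist())  # [0.5, 0.0, 3.0, 0.0, 0.0]
```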
What you can do is apply your condition to get a boolean mask, then find the indices of its True entries with torch.nonzero().
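import torch
a = torch.randn(10)
b = a <= 0  # boolean mask: True where the condition holds
indices = b.nonzero()  # shape (k, 1): one row per matching index
# For a flat 1-D result instead: indices = b.nonzero().squeeze(1)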
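For tensors with more than one dimension, the output shape of nonzero matters. The sketch below uses a made-up 2-D tensor to contrast the default output (one row of coordinates per match) with as_tuple=True (one index tensor per dimension, which can be used directly for indexing):

```python
import torch

# Illustrative 2-D tensor; values chosen so the result is easy to verify.
x = torch.tensor([[1.0, -2.0],
                  [-3.0, 4.0]])
mask = x <= 0

# Default form: a (k, ndim) tensor with one row of coordinates per match.
coords = mask.nonzero()
print(coords.tolist())  # [[0, 1], [1, 0]]

# as_tuple=True: one 1-D index tensor per dimension, usable for indexing.
rows, cols = mask.nonzero(as_tuple=True)
print(x[rows, cols].tolist())  # [-2.0, -3.0]

# For a 1-D tensor, squeeze the trailing dimension to get flat indices.
a = torch.tensor([0.3, -0.7, -0.2])
flat = (a <= 0).nonzero().squeeze(1)
print(flat.tolist())  # [1, 2]
```

If you only need the matching values and not their positions, boolean indexing (x[mask]) gives them directly.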
Best,