Does torch have a function that helps in finding the indices satisfying a condition?
For instance,
F = torch.randn(10)
b = torch.index(F <= 0)
print(b)
Thanks!
Hi,
PyTorch follows the same convention as NumPy for finding the values or indices of a tensor that satisfy a specific condition: see torch.where and torch.nonzero.
What you can do is apply your condition to get a boolean mask, then find the indices of its True entries with torch.nonzero():
import torch
a = torch.randn(10)
b = a <= 0              # boolean mask: True where the condition holds
indices = b.nonzero()   # shape [num_matches, 1]: index of each True entry
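For a 1-D input, note that nonzero() returns a column of shape [num_matches, 1]. Here is a small sketch (using a fixed input instead of randn so the result is predictable) of two ways to get a flat index tensor: squeezing the extra dimension, or using the single-argument form of torch.where, which behaves like nonzero(as_tuple=True) and returns a tuple with one index tensor per dimension.

```python
import torch

a = torch.tensor([0.5, -1.0, 2.0, -0.3, 0.0])
mask = a <= 0  # tensor([False, True, False, True, True])

# nonzero() gives shape [num_matches, 1] for 1-D input; squeeze to get a flat index tensor
idx_2d = mask.nonzero()
idx = idx_2d.squeeze(1)

# Single-argument torch.where returns a tuple (one entry per dimension), like nonzero(as_tuple=True)
(idx_where,) = torch.where(mask)

print(idx.tolist())        # [1, 3, 4]
print(idx_where.tolist())  # [1, 3, 4]
print(a[idx])              # the values that satisfy the condition
```

Either flat form can be used directly to index the original tensor, as in the last line.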
Best,
Thanks, @Nikronic. This really solves it.
How can you do it for a 2-D MxM tensor?
The same way, i.e. with .nonzero().
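For a 2-D tensor, each row of the nonzero() result is a (row, col) pair. Passing as_tuple=True instead gives one index tensor per dimension, which can be used directly for advanced indexing. A small sketch with a fixed 3x3 matrix:

```python
import torch

M = torch.tensor([[1.0, -2.0, 3.0],
                  [-4.0, 5.0, -6.0],
                  [7.0, 8.0, -9.0]])
mask = M <= 0

# Shape [num_matches, 2]; each row is a (row, col) coordinate, in row-major order
coords = mask.nonzero()

# One index tensor per dimension: (row indices, column indices)
rows, cols = mask.nonzero(as_tuple=True)

print(coords.tolist())         # [[0, 1], [1, 0], [1, 2], [2, 2]]
print(M[rows, cols].tolist())  # [-2.0, -4.0, -6.0, -9.0]
```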
Hi @harpone,
Is it possible to find the index of an N-D tensor inside another tensor? For example, can I find the index of torch.tensor([1, 2]) in torch.tensor([[1, 2], [3, 4], [5, 6]])?
Yes, your latter tensor is a 2-D tensor, and .nonzero() works fine with it. In general, for an N-D tensor you get an index tensor of shape [num_matches, N].
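To match a whole row like [1, 2] rather than individual elements, one way (a sketch, not the only approach) is to compare with broadcasting, reduce with .all(dim=1) so that a row counts only if every entry matches, and then call .nonzero() on the resulting 1-D row mask:

```python
import torch

haystack = torch.tensor([[1, 2], [3, 4], [5, 6]])
query = torch.tensor([1, 2])

# Broadcasting compares the query against every row -> bool tensor of shape [3, 2]
elementwise = haystack == query

# A row matches only if all of its entries match -> bool tensor of shape [3]
row_mask = elementwise.all(dim=1)

# Flat tensor of matching row indices
row_indices = row_mask.nonzero().squeeze(1)
print(row_indices.tolist())  # [0]
```

Calling .nonzero() on elementwise directly would instead return the [num_matches, 2] coordinates of every matching element, which is usually not what you want here.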
Hello. Is there any way to do this without torch.nonzero or torch.where? These functions produce a NonZero node in ONNX, and this NonZero node cannot be converted to TensorRT.
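One possible workaround, sketched under assumptions: if a fixed-size output is acceptable, you can avoid nonzero/where by giving every non-matching position a sort key larger than any matching one, sorting, and padding the unused slots with -1. That way the output shape depends only on the input shape, not on the data. The helper name fixed_size_indices and the -1 padding are hypothetical choices of mine, and I have not checked the exported graph against a particular ONNX opset or TensorRT version (torch.sort itself must be supported there), so treat this only as a starting point:

```python
import torch


def fixed_size_indices(mask: torch.Tensor):
    """Hypothetical helper: indices of True entries in a 1-D bool mask,
    returned as a fixed-size tensor (length = mask length) padded with -1,
    plus the number of valid entries. Uses no nonzero/where calls."""
    n = mask.shape[0]
    positions = torch.arange(n, device=mask.device)
    # Matching positions keep key i; non-matching ones get i + n, so they sort last
    keys = positions + (~mask).long() * n
    sorted_keys, _ = torch.sort(keys)
    count = mask.long().sum()
    # The first `count` slots hold the matching indices in ascending order
    valid = positions < count
    # Keep valid slots, replace the rest with -1
    padded = sorted_keys * valid.long() - (~valid).long()
    return padded, count


a = torch.tensor([0.5, -1.0, 2.0, -0.3, 0.0])
padded, count = fixed_size_indices(a <= 0)
print(padded.tolist())  # [1, 3, 4, -1, -1]
print(int(count))       # 3
```

If you only need to zero out or combine values rather than obtain explicit indices, multiplying by the mask (e.g. a * (a <= 0)) avoids index computation entirely.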