I have a question: is there any method to select a number in a tensor and return its index? For example:
t = [[1, 2],
     [2, 3],
     [5, 2]]
select(t, 2, dim=0) returns [1, 0, 1],
and select(t, 2, dim=1) returns [1, 0].
You can use a combination of .eq() (or ==) and nonzero().
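import torch

t = torch.tensor([[1, 2],
                  [2, 3],
                  [5, 2]])
print((t == 2).nonzero())

This prints one (row, column) pair for every element equal to 2:

tensor([[0, 1],
        [1, 0],
        [2, 1]])

Here each row contains exactly one 2, so the second column of the result, [1, 0, 1], is the per-row index from your dim=0 example.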
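If you want exactly one index per row or per column, as in the question, a minimal sketch is to take argmax of the comparison mask along the other dimension. This assumes a 2-D tensor and that the first match should win when the value occurs more than once (which is what the dim=1 example suggests). The name select_index and the choice to return -1 for a row or column with no match are my own, not part of PyTorch.

```python
import torch


def select_index(t: torch.Tensor, value, dim: int) -> torch.Tensor:
    """Hypothetical helper mirroring the `select` from the question (2-D only).

    dim=0: for each row, the column index of the first element equal to `value`.
    dim=1: for each column, the row index of the first element equal to `value`.
    Rows/columns with no match get -1 (an arbitrary sentinel chosen here).
    """
    assert t.dim() == 2 and dim in (0, 1), "sketch only handles 2-D tensors"
    eq = t == value                       # bool mask of matches
    search_dim = 1 - dim                  # search along the *other* dimension
    # argmax returns the index of the first maximal value, i.e. the first match
    idx = eq.long().argmax(dim=search_dim)
    # argmax of an all-zero slice is 0, so flag slices with no match separately
    found = eq.any(dim=search_dim)
    return torch.where(found, idx, torch.full_like(idx, -1))


t = torch.tensor([[1, 2],
                  [2, 3],
                  [5, 2]])
print(select_index(t, 2, dim=0))  # tensor([1, 0, 1])
print(select_index(t, 2, dim=1))  # tensor([1, 0])
```

Compared with nonzero(), this always returns one entry per row or column, so it still lines up with the rows when a row has zero or several matches.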