How to do a complex index selection operation in a simple way

Here I have an N*4 matrix F with values between 0 and 1. I want to get an index tensor of shape (N,) such that:
for every row in F, it holds the index of the first value that is >= R;
if there is no such value in the row, it holds the index of the row's maximum value.
Here's an example (R = 0.7):
[
[0.5 0.8 0.9 0.3],
[0.5 0.8 0.7 0.3],
[0.6 0.4 0.3 0.4],
]
should give
[1,1,0]
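To pin the rule down before vectorizing it, here is a plain-Python sketch that applies it row by row to nested lists. The function name `select_index` is hypothetical, and ties for the maximum are assumed to resolve to the first maximal index, which the question leaves open.

```python
def select_index(rows, r):
    """Per row: index of the first value >= r, else index of the row's maximum."""
    result = []
    for row in rows:
        # Index of the first value that reaches the threshold, or None if there is none.
        first = next((i for i, v in enumerate(row) if v >= r), None)
        if first is None:
            # No value reaches r: fall back to the index of the (first) maximum.
            first = max(range(len(row)), key=lambda i: row[i])
        result.append(first)
    return result


F = [
    [0.5, 0.8, 0.9, 0.3],
    [0.5, 0.8, 0.7, 0.3],
    [0.6, 0.4, 0.3, 0.4],
]
print(select_index(F, 0.7))  # [1, 1, 0]
```

A loop like this is slow on large tensors, but it is handy as a reference when checking a vectorized expression.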

Thank you very much!

The following snippet solves your problem:

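import torch

a = torch.tensor(
    [
        [0.5, 0.8, 0.9, 0.3],
        [0.5, 0.8, 0.7, 0.3],
        [0.6, 0.4, 0.3, 0.4],
    ]
)
r = 0.7

# ((a >= r).cumsum(1) < 1).sum(1) counts the entries before the first value >= r,
# which equals the number of columns when a row has no such value.
# If a row reaches r at column k, its maximum is also >= r and so sits at column k or later;
# the element-wise minimum with the row-wise argmax therefore keeps k, and falls back
# to the argmax only for rows where no value reaches r.
torch.min(a.max(1)[1], ((a >= r).cumsum(1) < 1).sum(1))  # tensor([1, 1, 0])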
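One way to gain confidence in the one-liner is to compare it against a per-row loop on random data. This is a sketch assuming a recent PyTorch; the helper name `loop_reference`, the seed and the shapes are arbitrary choices for illustration. The threshold 0.75 is used because it is exactly representable in float32, so the loop (which works on Python floats) and the tensor comparison agree at the boundary.

```python
import torch


def loop_reference(a, r):
    """Hypothetical helper: first index with value >= r, else index of the row max."""
    out = []
    for row in a.tolist():
        hits = [i for i, v in enumerate(row) if v >= r]
        out.append(hits[0] if hits else row.index(max(row)))
    return torch.tensor(out)


torch.manual_seed(0)
r = 0.75
for _ in range(100):
    a = torch.rand(8, 4)
    # The cumsum/min expression, using >= as the question requires.
    fast = torch.min(a.max(1)[1], ((a >= r).cumsum(1) < 1).sum(1))
    assert torch.equal(fast, loop_reference(a, r))
print("vectorized expression matches the loop")
```

Since `torch.rand` practically never produces tied maxima, the tie-breaking of the fallback does not affect this check.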

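Note that (at least in torch 1.1) torch.zeros(n).argmax() == n - 1 on the CPU, while torch.zeros(n).cuda().argmax() == 0. Both are logically valid, but they aren't consistent. The cumsum expression above does not rely on argmin at all. If you only use CPU tensors, a shorter variant relies on argmin returning the last of several tied indices:

torch.min(a.max(1)[1], (a >= r).cumsum(1).argmin(1) + 1)

Be aware that this variant returns a wrong index when the first value of a row already reaches r and the row's maximum lies further right: for the row [0.8, 0.9, 0.1, 0.1] it gives 1 instead of 0. Current PyTorch documentation also states that torch.argmax and torch.argmin return the first of several tied indices, so the first expression is the safer choice.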
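If you would rather not depend on how argmax or argmin break ties when locating the first qualifying value, the same selection can be written with an explicit column index and torch.where. This is a sketch, not the answer's method; the function name `first_at_least` is hypothetical. Non-qualifying positions are filled with the column count, so the row-wise minimum is the first qualifying column whenever one exists.

```python
import torch


def first_at_least(a, r):
    """Hypothetical helper: per row, first column with a >= r, else argmax of the row."""
    mask = a >= r
    n_cols = a.shape[1]
    # Column indices broadcast to a's shape: [[0, 1, 2, ...], [0, 1, 2, ...], ...]
    cols = torch.arange(n_cols, device=a.device).expand_as(a)
    # Non-qualifying positions get n_cols, so the minimum picks the first qualifying column.
    first = torch.where(mask, cols, torch.full_like(cols, n_cols)).min(1).values
    # Rows without any qualifying value fall back to the index of their maximum.
    return torch.where(mask.any(1), first, a.argmax(1))


a = torch.tensor(
    [
        [0.5, 0.8, 0.9, 0.3],
        [0.5, 0.8, 0.7, 0.3],
        [0.6, 0.4, 0.3, 0.4],
    ]
)
print(first_at_least(a, 0.7))  # tensor([1, 1, 0])
```

It does a little more work than the cumsum one-liner, but each step states its intent directly; only the fallback for rows with tied maxima still depends on argmax's tie-breaking.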

Thanks for your time.

My PyTorch is 1.2.0, and I found that argmax always returns the last index in the case above (on the CPU), so I didn't think about consistency. I will follow your answer.

Again, thanks for your reply.
