I have an N×4 matrix F with values between 0 and 1, and I want to get an index tensor of shape (N,) such that:
for every row of F, it contains the index of the first value that is >= R;
if no value in that row is >= R, it contains the index of that row's maximum value.
Here's an example (R = 0.7):
[
[0.5, 0.8, 0.9, 0.3],
[0.5, 0.8, 0.7, 0.3],
[0.6, 0.4, 0.3, 0.4],
]
should give
[1, 1, 0]
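To pin down the rule before vectorizing it, here is a minimal, non-vectorized reference sketch that applies it row by row. The function name `first_at_least_or_argmax_loop` is hypothetical (not part of PyTorch); it is only meant as a readable baseline for checking faster versions against, and it is slow for large N because of the Python loop.

```python
import torch

def first_at_least_or_argmax_loop(F: torch.Tensor, R: float) -> torch.Tensor:
    """Hypothetical reference version: for each row, the index of the first
    value >= R, or the index of the row maximum if no value qualifies."""
    out = []
    for row in F:
        hits = (row >= R).nonzero()  # indices of qualifying values, shape (k, 1)
        if len(hits) > 0:
            out.append(int(hits[0, 0]))  # first qualifying index
        else:
            out.append(int(row.argmax()))  # fallback: index of the row maximum
    return torch.tensor(out, dtype=torch.long)

F = torch.tensor([[0.5, 0.8, 0.9, 0.3],
                  [0.5, 0.8, 0.7, 0.3],
                  [0.6, 0.4, 0.3, 0.4]])
print(first_at_least_or_argmax_loop(F, 0.7).tolist())  # [1, 1, 0]
```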
Note that (at least in torch 1.1) torch.zeros(n).argmax() == n - 1, while torch.zeros(n).cuda().argmax() == 0. Both results are logically valid, but they aren't consistent. If your operations only use CPU tensors, you can do it a little more simply with the following:
torch.min(a.max(1)[1], (a >= r).cumsum(1).argmin(1) + 1)
My PyTorch is 1.2.0, and I found that argmax always returns the last index in the case above (on CPU), so I hadn't thought about consistency. I will follow your answer.
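The argmin-based one-liner depends on argmin returning the last of several tied indices. As far as I know, recent PyTorch documentation states that argmax/argmin return the first index among ties, so that behavior should not be relied on. The one-liner also misses rows whose first element already qualifies: for the row [0.8, 0.9, 0.1, 0.1] with r = 0.7, the cumsum of the mask is [1, 2, 2, 2], so argmin(1) + 1 gives 1 and the outer torch.min keeps 1, while the expected answer is 0. Below is a sketch that avoids tie-breaking altogether by counting the positions that come before the first hit (those are exactly the positions where the cumsum is still 0). The function name is hypothetical, and the only remaining tie dependence is in argmax when a row's maximum is itself tied, which the original rule leaves unspecified.

```python
import torch

def first_at_least_or_argmax(F: torch.Tensor, R: float) -> torch.Tensor:
    """Hypothetical vectorized version: per row, index of the first value >= R,
    else index of the row maximum. Same result on CPU and CUDA."""
    mask = F >= R                                   # (N, C) bool
    # The cumsum stays 0 until the first qualifying value, so the number of
    # zeros equals the index of the first hit (or C if there is no hit).
    first_hit = (mask.long().cumsum(dim=1) == 0).sum(dim=1)
    row_max = F.argmax(dim=1)                       # fallback for rows with no hit
    return torch.where(mask.any(dim=1), first_hit, row_max)

F = torch.tensor([[0.5, 0.8, 0.9, 0.3],
                  [0.5, 0.8, 0.7, 0.3],
                  [0.6, 0.4, 0.3, 0.4]])
print(first_at_least_or_argmax(F, 0.7).tolist())  # [1, 1, 0]
```

Using torch.where with mask.any(dim=1) makes the fallback explicit instead of relying on torch.min: rows with a hit take first_hit, and all other rows take the argmax.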