How to get the max value with smallest index

How do I get the max value with the smallest index (if there are multiple max values) of a 2D tensor? Thanks.


If you just want the max value, then torch.max will do the trick.
If you specify the dimension over which to take the max, then it returns two tensors, the max values and their indices.

maxes, indices = torch.max(my_tensor, dim=0)
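For illustration, here is what that call returns on a small made-up tensor with no repeated maxima (the values are arbitrary). With dim given, the result unpacks into the max values and their indices:

```python
import torch

# Illustrative tensor with no repeated maxima in any column
my_tensor = torch.tensor([[1.0, 5.0, 3.0],
                          [4.0, 2.0, 6.0]])

# Reduce over dim=0: one max (and the row it came from) per column
maxes, indices = torch.max(my_tensor, dim=0)
print(maxes.tolist())    # [4.0, 5.0, 6.0]
print(indices.tolist())  # [1, 0, 1]
```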

I don't know whether it will return the smallest index if there are two identical max values, though if the values are all floats, it is not very likely that there will be two identical max values.
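If you are unsure whether ties occur in your data, one way to check (a sketch on a made-up tensor) is to count how many entries in each column equal that column's max. Any count above 1 means the returned index depends on how ties are broken:

```python
import torch

# Illustrative tensor: column 1 has its max (7.0) in two rows
my_tensor = torch.tensor([[3.0, 7.0],
                          [1.0, 7.0],
                          [2.0, 0.0]])

maxes, _ = torch.max(my_tensor, dim=0)
# Count, per column, how many entries equal that column's max
n_max = (my_tensor == maxes.unsqueeze(0)).sum(dim=0)
has_tie = n_max > 1
print(n_max.tolist())    # [1, 2]
print(has_tie.tolist())  # [False, True]
```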

Hey! As this thread discusses, the index returned may be randomly selected from all max values:
https://discuss.pytorch.org/t/function-torch-max-return-indices-inconsistency-between-cup-and-gpu/1890/2

Here is a link to a solution to this on stackoverflow:
https://stackoverflow.com/questions/55139801/index-selection-in-case-of-conflict-in-pytorch-argmax/55146371#55146371
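One tie-independent way to get the first max index per row, sketched here as a hypothetical helper first_argmax (not necessarily the approach in the linked answer), is to give the max positions strictly decreasing scores so that argmax never has to break a tie:

```python
import torch

def first_argmax(a: torch.Tensor) -> torch.Tensor:
    """Index of the first maximum in each row of a 2D tensor (hypothetical helper)."""
    n = a.shape[1]
    # 1 where an entry equals its row max, 0 elsewhere
    is_max = (a == a.max(dim=1, keepdim=True).values).long()
    # Strictly decreasing weights n, n-1, ..., 1 so earlier columns score higher
    weights = torch.arange(n, 0, -1, device=a.device)
    # Each row now has a unique largest score, so argmax has no tie to break
    return torch.argmax(is_max * weights, dim=1)

a = torch.tensor([[ 0, 0, 2, 2, 0, 2, 0, 0],
                  [-1, 0, 0, 1, 0, 0, 1, 0]])
print(first_argmax(a).tolist())  # [2, 3]
```

Using torch.arange(1, n + 1, device=a.device) as the weights instead would select the last max index.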


Is it faster than doing it on the CPU?

argmax = torch.max(my_tensor.cpu(), dim=0)[1]

given that the CPU is known to compute the correct result.

Hi! It's not that argmax or max produce an incorrect result, but that the result they produce may not be consistent across runs and devices, and is not guaranteed to be the first/last max index, even if that is what you want.

The implementation in the link above, though, is guaranteed to give you either the first or last argmax index.
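The original question asks for a single max over the whole 2D tensor. A minimal sketch for that case, with a hypothetical helper name, flattens the tensor, keeps only the positions equal to the max, takes the smallest flat index (row-major order), and converts it back to (row, col):

```python
import torch

def max_with_smallest_index(a: torch.Tensor):
    """Max of a 2D tensor and (row, col) of its first occurrence in row-major order.

    Hypothetical helper: ties are resolved explicitly, not by torch.max/argmax.
    """
    flat = a.flatten()
    max_value = flat.max()
    idx = torch.arange(flat.numel(), device=a.device)
    # Non-max positions get an out-of-range index so min() ignores them
    sentinel = torch.full_like(idx, flat.numel())
    first = torch.where(flat == max_value, idx, sentinel).min()
    row, col = divmod(first.item(), a.shape[1])
    return max_value, (row, col)

value, pos = max_with_smallest_index(torch.tensor([[1, 9, 3],
                                                   [9, 0, 9]]))
print(value.item(), pos)  # 9 (0, 1)
```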

Just checked and you are right.

Sorry, I didn't include the link before:

But according to that answer (which no longer holds for current PyTorch versions), the CPU versions of torch.max and torch.min always returned the correct argmax, and that is no longer the case.

It can be done in the following way:

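import torch

def argmax_first_and_last(a):
    # Column indices 0..n-1 repeated for every row, on the same device as `a`
    b = torch.stack([torch.arange(a.shape[1], device=a.device)] * a.shape[0])
    max_values, _ = torch.max(a, dim=1)
    # Push non-max positions past the last column so min() skips them
    b[a != max_values[:, None]] = a.shape[1]
    first_max, _ = torch.min(b, dim=1)
    # Push non-max positions below 0 so max() skips them
    b[a != max_values[:, None]] = -1
    last_max, _ = torch.max(b, dim=1)
    return first_max, last_max

a = torch.tensor([[ 0, 0, 2, 2, 0, 2, 0, 0],
                  [-1, 0, 0, 1, 0, 0, 1, 0]])

first_max, last_max = argmax_first_and_last(a)
print(first_max) # tensor([2, 3])
print(last_max)  # tensor([5, 6])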
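The same idea can be written without in-place masking, and for any dimension, using torch.where. This is a hypothetical generalization (argmax_first_and_last_dim is an illustrative name) and assumes a PyTorch version that provides torch.amax/torch.amin:

```python
import torch

def argmax_first_and_last_dim(a: torch.Tensor, dim: int = -1):
    """First and last index of the max along `dim` (hypothetical generalization)."""
    n = a.shape[dim]
    is_max = a == a.amax(dim=dim, keepdim=True)
    # Positions 0..n-1 along `dim`, shaped to broadcast against `a`
    shape = [1] * a.dim()
    shape[dim] = n
    pos = torch.arange(n, device=a.device).view(shape)
    # Out-of-range sentinels make non-max positions lose the min/max reduction
    first = torch.where(is_max, pos, torch.full_like(pos, n)).amin(dim=dim)
    last = torch.where(is_max, pos, torch.full_like(pos, -1)).amax(dim=dim)
    return first, last

a = torch.tensor([[ 0, 0, 2, 2, 0, 2, 0, 0],
                  [-1, 0, 0, 1, 0, 0, 1, 0]])
first, last = argmax_first_and_last_dim(a, dim=1)
print(first.tolist(), last.tolist())  # [2, 3] [5, 6]
```

Passing dim=0 instead reduces over rows, giving the first/last row index of each column's max.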