How can I get the max value of a 2D tensor together with its index, taking the smallest index if there are multiple max values? Thanks.
If you just want the max value, then torch.max will do the trick.
If you specify the dimension over which to take the max, then it returns two tensors, the max values and their indices.
maxes, indices = torch.max(my_tensor, dim=0)
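A small illustration of both call forms follows. It uses made-up values with no ties, so the returned indices are unambiguous and the example does not depend on how ties are broken.

```python
import torch

# Illustrative values with no ties, so every max position is unique.
t = torch.tensor([[1.0, 5.0, 3.0],
                  [4.0, 2.0, 6.0]])

# Without dim: a single scalar tensor holding the overall maximum.
overall = torch.max(t)

# With dim=0: reduce over rows, giving one (value, row index) per column.
col_max, col_idx = torch.max(t, dim=0)

# With dim=1: reduce over columns, giving one (value, column index) per row.
row_max, row_idx = torch.max(t, dim=1)

print(overall)           # tensor(6.)
print(col_max, col_idx)  # tensor([4., 5., 6.]) tensor([1, 0, 1])
print(row_max, row_idx)  # tensor([5., 6.]) tensor([1, 2])
```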
I don't know whether it will return the smallest index if there are two identical max values, though if the values are all floats, it is not very likely that there will be two identical max values.
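If you need the smallest index of the overall max of the 2D tensor regardless of how `torch.max` breaks ties, one option is to locate every max position explicitly. The sketch below is not from the thread; `first_max_position` is a hypothetical helper, and it relies on `torch.nonzero` returning coordinates sorted in row-major (C-style) order, so its first row is the tied max with the smallest flattened index.

```python
import torch

def first_max_position(x: torch.Tensor):
    """Hypothetical helper: (max value, (row, col)) using the smallest row-major index on ties."""
    max_value = x.max()
    # nonzero() lists the coordinates of True entries sorted row-major,
    # so the first entry is the earliest occurrence of the max.
    row, col = (x == max_value).nonzero()[0].tolist()
    return max_value.item(), (row, col)

x = torch.tensor([[0, 3, 1],
                  [3, 2, 3]])
print(first_max_position(x))  # (3, (0, 1))
```

This builds a full boolean mask, so it costs an extra pass over the tensor compared with a single `argmax`.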
Hey! As this thread discusses, the index returned may be randomly selected from all max values:
Here is a link to a solution to this on Stack Overflow:
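For a per-row result, another trick (not necessarily the one in the linked answer) is to weight each max position so that earlier columns get larger weights, then take a plain `argmax`. Because non-max entries become 0 and the weights are distinct and positive, each row has exactly one largest entry, so ties never reach `argmax`. `first_argmax_rows` is a hypothetical name used only for this sketch.

```python
import torch

def first_argmax_rows(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: smallest column index of each row's maximum in a 2D tensor."""
    n = a.shape[1]
    # True where an entry equals its row's maximum.
    is_max = a == a.max(dim=1, keepdim=True).values
    # Weights n, n-1, ..., 1: earlier columns weigh more; non-max entries become 0,
    # so the largest weighted entry in each row is unique and sits at the first max.
    weights = torch.arange(n, 0, -1, device=a.device)
    return (is_max * weights).argmax(dim=1)

a = torch.tensor([[ 0, 0, 2, 2, 0, 2, 0, 0],
                  [-1, 0, 0, 1, 0, 0, 1, 0]])
print(first_argmax_rows(a))  # tensor([2, 3])
```

Using weights 1, 2, ..., n instead would select the last max index in each row.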
Is it faster than doing it on the CPU,
_, argmax = torch.max(my_tensor.cpu(), dim=0)
since the CPU is known to compute the correct result?
Hi! It's not that max does not produce a correct result; it's that the result it produces is not necessarily consistent across runs and devices, and it is not guaranteed to give you the first (or last) max index, even if that is what you want.
The implementation in the link above, though, is guaranteed to give you either the first or the last max index.
Just checked and you are right.
Sorry, I didn't include the link before:
That answer says the CPU version of torch.min always returned the correct argmax, but this no longer holds in current PyTorch versions.
It can be done in the following way:
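import torch

def argmax_first_and_last(a):
    # b[i, j] = j: the column indices, repeated for every row (on the same device as a)
    b = torch.stack([torch.arange(a.shape[1], device=a.device)] * a.shape[0])
    max_values, _ = torch.max(a, dim=1)
    not_max = a != max_values[:, None]
    # Non-max positions get a value larger than any column index, so min finds the first max
    b[not_max] = a.shape[1]
    first_max, _ = torch.min(b, dim=1)
    # Non-max positions get -1, so max finds the last max
    b[not_max] = -1
    last_max, _ = torch.max(b, dim=1)
    return first_max, last_max

a = torch.tensor([[ 0, 0, 2, 2, 0, 2, 0, 0],
                  [-1, 0, 0, 1, 0, 0, 1, 0]])
first_max, last_max = argmax_first_and_last(a)
print(first_max)  # tensor([2, 3])
print(last_max)   # tensor([5, 6])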
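To sanity-check a vectorized version like this on your own data, one option is to compare it against a slow plain-Python reference. `first_last_max_reference` below is a hypothetical helper written for this purpose. For a 2D tensor `a`, `first_last_max_reference(a.tolist())` should equal `(first_max.tolist(), last_max.tolist())`.

```python
def first_last_max_reference(rows):
    """Hypothetical reference: (first, last) index of each row's maximum, in plain Python."""
    first, last = [], []
    for row in rows:
        m = max(row)
        # All positions holding the row maximum, in increasing order.
        positions = [j for j, v in enumerate(row) if v == m]
        first.append(positions[0])
        last.append(positions[-1])
    return first, last

rows = [[ 0, 0, 2, 2, 0, 2, 0, 0],
        [-1, 0, 0, 1, 0, 0, 1, 0]]
print(first_last_max_reference(rows))  # ([2, 3], [5, 6])
```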