Get indices of the max of a 2D Tensor

I have a 2d Tensor A of shape (d, d) and I want to get the indices of its maximal element.
torch.argmax only returns a single index into the flattened tensor. For now I got the result doing the following, but it seems convoluted.

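vals, row_idx = A.max(0)   # max of each column and the row where it occurs
col_idx = vals.argmax(0)   # column holding the overall maximum
row = row_idx[col_idx]     # row of the overall maximum

And then, A[row, col_idx] is the correct maximal value. Is there any more straightforward way to get this?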

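A minimal sketch of the column-wise approach from the question, wrapped in a hypothetical helper `max_index_two_step` (the name is illustrative, not part of PyTorch). The detail that is easy to miss is that `row_idx` holds one row index per column, so it has to be indexed with the winning column before it points at the maximum.

```python
from typing import Tuple

import torch


def max_index_two_step(A: torch.Tensor) -> Tuple[int, int]:
    """Hypothetical helper: (row, col) of the maximum of a 2D tensor."""
    # vals[j] is the maximum of column j; row_idx[j] is the row where it occurs
    vals, row_idx = A.max(0)
    # Column whose maximum is the largest overall
    col = vals.argmax().item()
    # Keep only the row index that belongs to that column
    row = row_idx[col].item()
    return row, col


A = torch.randn(5, 5)
row, col = max_index_two_step(A)
assert A[row, col] == A.max()
```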

Alternatively, this code should also work:

import torch
x = torch.randn(10, 10)
print((x == torch.max(x)).nonzero())
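One behavior worth knowing about the nonzero() one-liner: it returns every position equal to the maximum, so a tensor with tied maxima yields more than one (row, col) row. argmax() reports a single flat index; the torch.argmax documentation states that the first maximal value is returned. A small sketch with a deliberately tied tensor:

```python
import torch

# The maximum value 5.0 appears at two positions
x = torch.tensor([[1.0, 5.0],
                  [5.0, 0.0]])

# nonzero() returns one (row, col) pair per position equal to the maximum
coords = (x == torch.max(x)).nonzero()

# argmax() reports a single flat index, the first maximum in row-major order
flat = x.argmax().item()
```

Here coords has two rows, while flat refers only to the first of the two maxima.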

For anyone who stumbles in here wondering which approach is faster, GeoffNN's or ptrblck's: ptrblck's one-liner appears to be at least twice as fast.

In my (not rigorous) benchmarking, ptrblck’s code found the max indices of 100k random tensors in an average of 1.6 seconds, and the solution by GeoffNN found the max indices of 100k random tensors in an average of 3.5 seconds.

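A rough timing harness for comparing the two approaches yourself, using the standard-library timeit module. The tensor size and call count are assumptions, since the post does not state what was benchmarked, and the numbers will vary with hardware and PyTorch version; no particular result is implied here.

```python
import timeit

import torch

# Assumed size: the post does not say how large the benchmarked tensors were
x = torch.randn(10, 10)


def via_nonzero():
    # ptrblck's one-liner: coordinates of every element equal to the max
    return (x == torch.max(x)).nonzero()


def via_two_step():
    # Column-wise max, then pick the column with the largest maximum
    vals, row_idx = x.max(0)
    col = vals.argmax(0)
    return row_idx[col], col


n = 10_000  # calls per approach; raise it for steadier numbers
t_nonzero = timeit.timeit(via_nonzero, number=n)
t_two_step = timeit.timeit(via_two_step, number=n)
print(f"nonzero:  {t_nonzero / n * 1e6:.1f} us per call")
print(f"two-step: {t_two_step / n * 1e6:.1f} us per call")
```

For more careful measurements (warm-up, GPU synchronization), torch.utils.benchmark is the more robust choice.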

Is there any way to batch this?

Suppose that now I have a 3D tensor of shape (batch_size, d, d) and I want to get, for each 2D datapoint in the batch, the index of its maximal element. I can’t figure out how to generalize the (x==torch.max(x)).nonzero() approach.

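One way to generalize the nonzero() approach to a batch, sketched under the assumption that Tensor.amax with a tuple of dims is available (it is in recent PyTorch releases): compute each item's maximum with keepdim=True so it broadcasts against the whole batch, then call nonzero() once.

```python
import torch

a = torch.randn(4, 5, 5)  # (batch_size, d, d)

# Maximum of each 2D item, kept as shape (batch_size, 1, 1) so it broadcasts
item_max = a.amax(dim=(1, 2), keepdim=True)

# Each row of `hits` is (batch_index, row, col) for one maximal element;
# an item with tied maxima contributes more than one row
hits = (a == item_max).nonzero()
```

When every item has exactly one maximum, hits[:, 1:] is a (batch_size, 2) tensor of coordinates in batch order, because nonzero() lists matches in row-major order.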

Since argmax() gives you the index in a flattened tensor, you can infer the position in your 2D tensor from the size of the last dimension. E.g. if argmax() returns 10 and you’ve got 4 columns, you know it’s on row 2, column 2 (0-based). You can use Python’s divmod for this:

import torch
a = torch.randn(10, 10)
row, col = divmod(a.argmax().item(), a.shape[1])
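The divmod trick works because a 2D tensor is flattened in row-major order, so flat = row * ncols + col, and divmod(flat, ncols) recovers (row, col). A short check of the worked example and of the round trip:

```python
import torch

# Worked example from the text: flat index 10 with 4 columns
assert divmod(10, 4) == (2, 2)  # row 2, column 2 (0-based)

# In a row-major 2D tensor, flat = row * ncols + col, so divmod inverts it
a = torch.randn(6, 4)
ncols = a.shape[1]
flat = a.argmax().item()
row, col = divmod(flat, ncols)
assert row * ncols + col == flat
assert a[row, col] == a.max()
```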

This can be extended to the batch case by flattening the last two dimensions, and iterating over the results to return several 2D indexes:

import torch
a = torch.randn(3, 2, 100)
flat_indexes = a.flatten(start_dim=-2).argmax(1)
indices = [divmod(idx.item(), a.shape[-1]) for idx in flat_indexes]
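The Python loop over flat_indexes can also be replaced by tensor arithmetic, which keeps everything on the tensor's device and returns a single (batch, 2) tensor. A sketch assuming a PyTorch version whose torch.div accepts rounding_mode (added around 1.8, as far as I know):

```python
import torch

a = torch.randn(3, 2, 100)  # (batch, rows, cols), same shape as in the post
ncols = a.shape[-1]

# One flat index per batch item, shape (batch,)
flat_indexes = a.flatten(start_dim=-2).argmax(dim=1)

# Floor division and remainder convert all flat indexes at once, no Python loop
rows = torch.div(flat_indexes, ncols, rounding_mode="floor")
cols = flat_indexes % ncols

coords = torch.stack([rows, cols], dim=1)  # shape (batch, 2)
```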

I also need to compute, for each item in a batch, the 2D coordinates of its max value.
A simple workaround may be:

# suppose the tensor is of shape (3, 2, 2):
>>> import torch
>>> a = torch.randn(3, 2, 2)
>>> a
tensor([[[ 0.1450, -1.3480],
         [-0.3339, -0.5133]],

        [[ 0.6867, -0.2972],
         [ 0.8768,  0.0844]],

        [[-2.3115, -0.4549],
         [-1.5074, -0.8706]]])

# then find the coordinates of each item's max
>>> torch.stack([(a[i]==torch.max(a[i])).nonzero() for i in range(a.size(0))], dim=0)

tensor([[[0, 0]],

        [[1, 0]],

        [[0, 1]]])
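The torch.stack call above assumes each item has exactly one maximum. If an item has tied maxima, its nonzero() result has more rows than the others and torch.stack raises a RuntimeError because the shapes differ. A sketch of a tie-tolerant variant, using a hypothetical helper name, that keeps only the first match per item (nonzero() lists matches in row-major order) and returns shape (batch, 2) rather than (batch, 1, 2):

```python
import torch


def batch_max_coords_first(a: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (row, col) of the first maximum of each 2D item."""
    coords = []
    for i in range(a.size(0)):
        # nonzero() lists matches in row-major order, so [0] is the first one
        coords.append((a[i] == torch.max(a[i])).nonzero()[0])
    return torch.stack(coords, dim=0)  # shape (batch, 2)


# Item 0 has a single maximum (3.0); item 1 has its maximum (2.0) twice
a = torch.tensor([[[0.0, 1.0], [3.0, 0.0]],
                  [[2.0, 0.0], [0.0, 2.0]]])
coords = batch_max_coords_first(a)
```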