Get indices of the max of a 2D Tensor

I have a 2D tensor A of shape (d, d) and I want to get the indices of its maximal element.
torch.argmax without a dim argument only returns a single index into the flattened tensor. For now I get the result with the following, but it seems convoluted.

vals, row_idx = A.max(0)
col_idx = vals.argmax(0)

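Then A[row_idx[col_idx], col_idx] is the maximal value. Note that row_idx has to be indexed by col_idx: A[row_idx, col_idx] with the full row_idx returns d entries of column col_idx instead of a single value. Is there a more straightforward way to get this?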
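A runnable sketch of the two-step approach above, using a small hand-written matrix (the values of A are only an illustration). A.max(0) reduces over rows, so vals[j] is the maximum of column j and row_idx[j] is the row where it sits; the row of the global maximum is therefore row_idx[col_idx].

```python
import torch

# Illustrative (d, d) matrix; the global max 9.0 sits at row 1, column 2.
A = torch.tensor([[1.0, 7.0, 3.0],
                  [4.0, 2.0, 9.0],
                  [5.0, 8.0, 6.0]])

# Column-wise max: vals[j] is the max of column j, row_idx[j] the row it is in.
vals, row_idx = A.max(0)
# Column that holds the overall maximum.
col_idx = vals.argmax(0)
# Look up the row for that specific column only.
row = row_idx[col_idx]

print(row.item(), col_idx.item())  # 1 2
print(A[row, col_idx].item())      # 9.0
```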


Alternatively, this code should also work:

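import torch
x = torch.randn(10, 10)
# nonzero() lists the (row, col) of every entry equal to the global max (one row per tie)
print((x == torch.max(x)).nonzero())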
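A related option, not from the replies above, is to take the flattened argmax and convert it back to (row, col) yourself. The sketch below assumes you want exactly one index even when the maximum is tied; the PyTorch docs state that argmax then returns the first maximal position, whereas the nonzero() version returns every tied position.

```python
import torch

torch.manual_seed(0)  # only to make the example reproducible
x = torch.randn(10, 10)

# Without a dim argument, argmax returns one index into the flattened tensor.
flat_idx = torch.argmax(x).item()
# Flattening is row-major, so divmod by the number of columns gives (row, col).
row, col = divmod(flat_idx, x.shape[1])

print(row, col, x[row, col].item() == x.max().item())
```

Newer PyTorch releases also provide torch.unravel_index for the flat-to-(row, col) conversion, if your version has it.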

For anyone who stumbles in here and wonders which approach is faster (GeoffNN's or ptrblck's), ptrblck's one-liner appears to be at least twice as fast.

In my (not rigorous) benchmarking, ptrblck's code found the max indices of 100k random tensors in 1.6 seconds on average, versus 3.5 seconds on average for GeoffNN's solution.
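If you want to repeat a rough comparison like this on your own setup, here is a minimal sketch using timeit. The tensor size (10 x 10), the repetition count and the function names are assumptions for illustration; the timings reported above did not state their tensor size, and results will vary with hardware (on a GPU you would also need to synchronize before reading the clock).

```python
import timeit

import torch


def max_idx_two_step(A):
    # Approach from the question: column-wise max, then argmax over those maxima.
    vals, row_idx = A.max(0)
    col_idx = vals.argmax(0)
    return row_idx[col_idx], col_idx


def max_idx_nonzero(A):
    # One-liner from the reply: positions of all entries equal to the global max.
    return (A == torch.max(A)).nonzero()


A = torch.randn(10, 10)  # hypothetical size
n = 10_000               # hypothetical repetition count; raise it for steadier numbers

for fn in (max_idx_two_step, max_idx_nonzero):
    seconds = timeit.timeit(lambda: fn(A), number=n)
    print(f"{fn.__name__}: {seconds:.3f} s for {n} calls")
```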


Is there any way to batch this?

Suppose I now have a 3D tensor of shape (batch_size, d, d) and want, for each 2D sample in the batch, the index of its maximal element. I can't figure out how to generalize the (x==torch.max(x)).nonzero() approach.
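One possible way to batch both ideas, sketched with assumed sizes batch_size = 4 and d = 5. Reducing with amax over the last two dims and keepdim=True leaves a (batch_size, 1, 1) tensor that broadcasts against x, so the equality compares each sample with its own maximum and nonzero() returns (batch, row, col) triples (ties give extra rows). Alternatively, flattening each sample and taking argmax along dim 1 gives exactly one index per sample.

```python
import torch

batch_size, d = 4, 5  # hypothetical sizes
x = torch.randn(batch_size, d, d)

# Option 1: generalize the nonzero() trick with a per-sample max.
per_sample_max = x.amax(dim=(1, 2), keepdim=True)  # shape (batch_size, 1, 1)
hits = (x == per_sample_max).nonzero()             # rows of (batch, row, col)

# Option 2: one index per sample via a flattened argmax.
flat_idx = x.flatten(1).argmax(dim=1)              # shape (batch_size,)
rows = torch.div(flat_idx, d, rounding_mode="floor")
cols = flat_idx % d

print(hits)
print(rows, cols)
```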
