I have a 2D tensor A of shape (d, d) and I want to get the (row, column) indices of its maximal element. torch.argmax returns only a single index into the flattened tensor. For now I get the result as follows, but it seems convoluted.

vals, row_idx = A.max(0)   # per-column maxima and the row where each one occurs
col_idx = vals.argmax(0)   # column that holds the overall maximum

Then A[row_idx[col_idx], col_idx] is the maximal value. Is there a more straightforward way to get this?
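
One possibly more direct route, offered as a sketch rather than a definitive answer: take the flat index from argmax and convert it back to (row, column) with integer division. The 4x4 random tensor is only a stand-in for A.

```python
import torch

A = torch.rand(4, 4)  # stand-in for the (d, d) tensor from the question

flat_idx = A.argmax().item()             # index into the flattened, row-major tensor
row, col = divmod(flat_idx, A.shape[1])  # flat_idx == row * n_cols + col

print(row, col, A[row, col].item())
```

This relies on argmax (called without a dim argument) indexing the tensor as if it were flattened in row-major order, so dividing by the number of columns recovers the row.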

For anyone who stumbles in here and wonders which approach is faster (that provided by GeoffNN or ptrblck), the one-liner by ptrblck appears to be at least twice as fast.

In my (not rigorous) benchmarking, ptrblck’s code found the max indices of 100k random tensors in an average of 1.6 seconds, and the solution by GeoffNN found the max indices of 100k random tensors in an average of 3.5 seconds.
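
For anyone who wants to repeat the comparison, here is a rough timing sketch. It assumes the two approaches are the two-step max/argmax from the question and the (x == x.max()).nonzero() one-liner mentioned in this thread; the tensor size and count are arbitrary choices, not the ones used for the numbers above.

```python
import timeit
import torch

def two_step(A):
    # Column-wise max first, then argmax over those maxima.
    vals, row_idx = A.max(0)
    col_idx = vals.argmax(0)
    return row_idx[col_idx].item(), col_idx.item()

def one_liner(A):
    # nonzero() returns one row per maximal element; keep the first match.
    return tuple((A == A.max()).nonzero()[0].tolist())

# Assumed workload: 1000 random 10x10 tensors (arbitrary sizes).
tensors = [torch.rand(10, 10) for _ in range(1000)]

for fn in (two_step, one_liner):
    seconds = timeit.timeit(lambda: [fn(A) for A in tensors], number=1)
    print(f"{fn.__name__}: {seconds:.3f} s")
```

Timings will vary with tensor size, device, and PyTorch version, so treat any single run as indicative only.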

Suppose I now have a 3D tensor of shape (batch_size, d, d) and want, for each 2D datapoint in the batch, the index of its maximal element. I can’t figure out how to generalize the (x==torch.max(x)).nonzero() approach.
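
One way this might be done, sketched here with flattening rather than the comparison-based approach: collapse the last two dimensions, take argmax along that flattened dimension, and convert each flat index back to a (row, column) pair. The shapes and variable names are illustrative assumptions.

```python
import torch

x = torch.rand(4, 5, 5)  # (batch_size, d, d); sizes are arbitrary
batch_size, n_rows, n_cols = x.shape

# One flat index per batch element, over the flattened (d * d) entries.
flat_idx = x.reshape(batch_size, -1).argmax(dim=1)

# Undo the row-major flattening: flat = row * n_cols + col.
rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
cols = flat_idx % n_cols

indices = torch.stack((rows, cols), dim=1)  # shape (batch_size, 2)
print(indices)
```

Each row of indices holds the (row, column) position of the maximum for the corresponding batch element, so x[torch.arange(batch_size), rows, cols] gives the per-sample maxima.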