I have a 2D tensor A of shape (d, d) and I want to get the indices of its maximal element. torch.argmax only returns a single index into the flattened tensor. For now I get the result with the following, but it seems convoluted.

vals, row_idx = A.max(0)
col_idx = vals.argmax(0)

And then, A[row_idx[col_idx], col_idx] is the correct maximal value. Is there any more straightforward way to get this?

For anyone who stumbles in here and wonders which approach is faster (that provided by GeoffNN or ptrblck), the one-liner by ptrblck appears to be at least twice as fast.

In my (not rigorous) benchmarking, ptrblck’s code found the max indices of 100k random tensors in an average of 1.6 seconds, and the solution by GeoffNN found the max indices of 100k random tensors in an average of 3.5 seconds.
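For anyone who wants to reproduce the comparison, a rough sketch of the two approaches (the exact tensor sizes used in the benchmark above aren't stated, so the 100×100 shape here is an assumption):

```python
import torch

def max_index_nonzero(x):
    # One-liner approach: compare against the max, then locate matches
    # (takes the first match if there happen to be ties)
    return (x == torch.max(x)).nonzero()[0]

def max_index_two_step(x):
    # Two-step approach: per-column max, then argmax over those maxima
    vals, row_idx = x.max(0)
    col_idx = vals.argmax(0)
    return row_idx[col_idx].item(), col_idx.item()

x = torch.randn(100, 100)
r1, c1 = max_index_nonzero(x).tolist()
r2, c2 = max_index_two_step(x)
# Both approaches land on a position holding the maximal value
assert x[r1, c1] == x.max()
assert x[r2, c2] == x.max()
```

You can wrap either call in timeit to repeat the comparison on your own hardware.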

Suppose that now I have a 3D tensor of shape (batch_size, d, d) and I want to get, for each 2D slice in the batch, the index of its maximal element. I can’t figure out how to generalize the (x==torch.max(x)).nonzero() approach.
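One way to generalize it is to take the max over the last two dimensions with amax (keeping them so the result broadcasts) and then call nonzero(); a sketch, assuming the max of each 2D slice is unique:

```python
import torch

x = torch.randn(4, 5, 5)  # (batch_size, d, d)
# Max of each 2D slice, kept as shape (batch_size, 1, 1) for broadcasting
batch_max = x.amax(dim=(-2, -1), keepdim=True)
# Each row of `idx` is (batch, row, col); one row per slice if maxima are unique
idx = (x == batch_max).nonzero()
```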

Since argmax() gives you the index in a flattened tensor, you can infer the position in your 2D tensor from the size of the last dimension. E.g. if argmax() returns 10 and you’ve got 4 columns, you know it’s on row 2, column 2 (0-indexed). You can use Python’s divmod for this:

a = torch.randn(10, 10)
row, col = divmod(a.argmax().item(), a.shape[1])
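To double-check the worked example above: a flat index of 10 with 4 columns is row 2, column 2.

```python
# Flat index 10 in a tensor with 4 columns
r, c = divmod(10, 4)
assert (r, c) == (2, 2)  # row 2, column 2
```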

This can be extended to the batch case by flattening the last two dimensions and iterating over the results to get a 2D index for each batch element:

a = torch.randn(3, 2, 100)
# Flatten the last two dims, then argmax over the flattened dimension
flat_indexes = a.flatten(start_dim=-2).argmax(1)
# Recover (row, col) within each 2D slice
[divmod(idx.item(), a.shape[-1]) for idx in flat_indexes]
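If you want to avoid the Python loop (e.g. to keep everything on the GPU), the same divmod arithmetic can be done with tensor ops; a sketch:

```python
import torch

a = torch.randn(3, 2, 100)
flat_indexes = a.flatten(start_dim=-2).argmax(1)
n_cols = a.shape[-1]
# Vectorized divmod: integer division gives rows, remainder gives columns
rows = torch.div(flat_indexes, n_cols, rounding_mode='floor')
cols = flat_indexes % n_cols
# rows and cols each have shape (batch_size,)
```

Newer PyTorch versions (2.2+, if I remember correctly) also provide torch.unravel_index, which does this conversion directly.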