Get indices of the max of a 2D Tensor

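Alternatively, you can compare the tensor against its maximum and take the nonzero positions. This returns the (row, column) index of every element equal to the max, so ties produce several rows: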

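import torch

x = torch.randn(10, 10)
# Boolean mask of positions equal to the max; nonzero() gives a (k, 2) tensor of [row, col] indices
print((x == torch.max(x)).nonzero())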
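If you only need a single location, a common alternative is to take `torch.argmax` over the flattened tensor and convert the flat index back to (row, col) with `divmod`. This is a minimal sketch: `argmax_2d` is a hypothetical helper name, and it assumes a contiguous row-major 2D tensor. The PyTorch docs state that `argmax` returns the first maximal value when there are ties.

```python
import torch

def argmax_2d(x):
    """Hypothetical helper: return (row, col) of a maximum of a 2D tensor."""
    flat_idx = torch.argmax(x).item()        # index into the flattened tensor
    row, col = divmod(flat_idx, x.shape[1])  # row-major: flat_idx = row * ncols + col
    return row, col

x = torch.randn(10, 10)
row, col = argmax_2d(x)
print(row, col, x[row, col].item() == x.max().item())
```

Unlike the `nonzero()` approach, this avoids building a full boolean mask, but it reports only one position when the maximum is repeated.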