Alternatively, this code should also work. It prints the (row, column) index of every element equal to the tensor's maximum:
import torch

x = torch.randn(10, 10)
print((x == torch.max(x)).nonzero())
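The snippet above can be checked on a small hand-written tensor. The sketch below uses made-up values chosen for illustration, including a deliberate tie for the maximum. It contrasts the mask-plus-`nonzero()` approach, which returns every maximal position, with a `torch.argmax` variant that returns a single position. `torch.argmax` works on the flattened tensor when no `dim` is given, and its documentation says it returns the first maximal index on ties.

```python
import torch

# Small fixed tensor (illustrative values) with the maximum 5.0 appearing twice.
x = torch.tensor([[1.0, 5.0, 2.0],
                  [5.0, 0.0, 3.0]])

# Mask approach: True wherever x equals its global max, then nonzero()
# returns one [row, col] pair per matching element, so ties are all kept.
all_max = (x == torch.max(x)).nonzero()
print(all_max)

# Single-position alternative: argmax over the flattened tensor gives a flat
# index, which divmod by the number of columns turns into (row, col).
flat = torch.argmax(x)
row, col = divmod(flat.item(), x.shape[1])
print(row, col)
```

Use the mask version when you need every occurrence of the maximum. Use the `argmax` version when one index is enough.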