Given a matrix tensor of size [A, B], how can I efficiently find the indices of the top K maximum values in the matrix?

e.g. matrix = torch.tensor([[4,2,7,1],[9,22,5,13],[6,4,8,25],[3,9,6,10],[4,1,6,14]]).
When K = 4, the top K maximum values are 22, 13, 25 and 14, so the answer should return their indices [1,1], [1,3], [2,3], [4,3].
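One possible approach, sketched here rather than offered as the definitive answer, is to flatten the matrix so that `torch.topk` ranks every element at once. Each flat index is then converted back to [row, col] using the column count B. The helper name `topk_indices_2d` is hypothetical. The extra `torch.sort` step is an assumption made only so the output comes out in the row-major order shown in the question; drop it if order does not matter.

```python
import torch

def topk_indices_2d(matrix: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: return a (k, 2) tensor of [row, col] indices
    of the k largest entries in a 2-D tensor."""
    ncols = matrix.shape[1]
    # Flatten to 1-D so topk ranks all A*B elements together.
    _, flat_idx = torch.topk(matrix.flatten(), k)
    # Optional: sort the flat indices so results appear in row-major order.
    flat_idx, _ = torch.sort(flat_idx)
    # Map each flat index back to (row, col).
    rows = torch.div(flat_idx, ncols, rounding_mode="floor")
    cols = flat_idx % ncols
    return torch.stack((rows, cols), dim=1)

matrix = torch.tensor([[4, 2, 7, 1], [9, 22, 5, 13], [6, 4, 8, 25],
                       [3, 9, 6, 10], [4, 1, 6, 14]])
print(topk_indices_2d(matrix, 4).tolist())
# [[1, 1], [1, 3], [2, 3], [4, 3]]
```

Everything stays inside a few vectorized tensor ops with no Python loop over elements, so the same code should work on GPU tensors. If several entries tie at the K-th position, `torch.topk` does not guarantee which of them is returned.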