Return indexes of Top K maximum values in a matrix

Given a 2-D tensor of size [A, B], how can I efficiently find the indexes of the top K maximum values in the matrix?

e.g. matrix = torch.tensor([[4,2,7,1],[9,22,5,13],[6,4,8,25],[3,9,6,10],[4,1,6,14]]).
When K = 4, the top K values are 22, 13, 25 and 14, so the answer should return their indexes [1,1], [1,3], [2,3], [4,3].
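
To pin down the expected answer, here is a minimal brute-force reference in plain Python (no torch). `topk_indices_reference` is a hypothetical helper, not a library function. It assumes a non-empty list-of-lists matrix and, like the example above, returns the [row, col] pairs in row-major order; ties are broken toward the earlier position, which is an assumption the question does not specify.

```python
# Hypothetical helper: brute-force reference for small matrices (plain Python, no torch).
def topk_indices_reference(matrix, k):
    """Return [row, col] pairs of the k largest entries, in row-major order."""
    n_cols = len(matrix[0])
    # Pair each value with its flat (row-major) position r * n_cols + c
    flat = [(value, r * n_cols + c)
            for r, row in enumerate(matrix)
            for c, value in enumerate(row)]
    # Take the k largest values; ties go to the smaller flat index (assumed rule)
    top = sorted(flat, key=lambda p: (-p[0], p[1]))[:k]
    # Restore row-major order, then split each flat index into [row, col]
    return [[i // n_cols, i % n_cols] for i in sorted(i for _, i in top)]


matrix = [[4, 2, 7, 1], [9, 22, 5, 13], [6, 4, 8, 25], [3, 9, 6, 10], [4, 1, 6, 14]]
print(topk_indices_reference(matrix, 4))  # [[1, 1], [1, 3], [2, 3], [4, 3]]
```

This is O(A·B·log(A·B)) and only meant as a cross-check for a tensor-based solution.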

Hi Ian,

You can try the following solution.

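import torch

x = torch.tensor([[4, 2, 7, 1], [9, 22, 5, 13], [6, 4, 8, 25], [3, 9, 6, 10], [4, 1, 6, 14]])
H, W = x.shape

K = 4
# Flatten to 1-D so topk searches the whole matrix (reshape also handles non-contiguous tensors)
flat = x.reshape(-1)
_, indices = flat.topk(K)
# Map each flat index back to (row, col): row = idx // W, col = idx % W
rows = torch.div(indices, W, rounding_mode='floor')
cols = indices % W
two_d_indices = torch.stack((rows, cols), dim=1)
print(two_d_indices)  # ordered by value, largest first: [[2, 3], [1, 1], [4, 3], [1, 3]]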
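
Note that `topk` returns results sorted by value (largest first), whereas the question lists the indexes in row-major order. A minimal sketch that wraps the same flatten-then-divide idea and optionally re-sorts by position is below; `topk_2d` and its `row_major` flag are hypothetical names, not part of PyTorch. It assumes a 2-D input tensor and `k <= x.numel()`.

```python
import torch


def topk_2d(x: torch.Tensor, k: int, row_major: bool = False):
    """Hypothetical helper: (values, indices) of the k largest entries of a 2-D tensor.

    indices has shape [k, 2]; each row is (row, col).
    """
    n_cols = x.shape[1]
    # reshape (not view) so non-contiguous inputs such as x.t() also work
    values, flat_idx = x.reshape(-1).topk(k)
    if row_major:
        # topk orders by value; re-sort by flat index to get row-major order
        flat_idx, order = flat_idx.sort()
        values = values[order]
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    return values, torch.stack((rows, cols), dim=1)


x = torch.tensor([[4, 2, 7, 1], [9, 22, 5, 13], [6, 4, 8, 25], [3, 9, 6, 10], [4, 1, 6, 14]])
values, idx = topk_2d(x, 4, row_major=True)
print(values.tolist())  # [22, 13, 25, 14]
print(idx.tolist())     # [[1, 1], [1, 3], [2, 3], [4, 3]]
```

If I remember correctly, newer PyTorch releases (2.2 and later) also ship `torch.unravel_index`, which can replace the manual divide/modulo step; check your installed version before relying on it. The order among tied values is not guaranteed by `topk`.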

Thanks

Hi Pranavan, thanks for your kind help.
