Calculate L2-norm to extract the K highest regions of feature maps

Hi.
I need to calculate the L2-norm of the image representation tensor img_fmap, which has shape [N, M, 25, 25], at each spatial location (i, j), i.e. over the M channels. Then I want to select the K highest values in each 25×25 map and set those locations to 1.

img_fmap = img_fmap.view(N, M, 25*25)
# dim=2 reduces over the 25*25 spatial positions, giving one squared L2 norm per channel
vector = torch.sum(img_fmap ** 2, dim=2)
topk_idxes = torch.topk(vector, k=self.k)[1]

However, the result has shape [N, M] instead of one value per spatial location.
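A minimal sketch of one way to get a value per location, assuming the L2-norm should be taken over the M channels at each (i, j) and that non-selected locations should become 0. The helper name `topk_location_mask` is hypothetical. The main difference from the code above is reducing over dim=1 (channels) instead of over the flattened spatial dimension, and only then flattening the 25×25 grid so `torch.topk` can rank all locations of each sample.

```python
import torch


def topk_location_mask(img_fmap: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: return a {0, 1} mask of shape [N, H, W] marking,
    for each sample, the k spatial locations whose M-dim feature vector has
    the largest L2 norm.

    Assumes img_fmap has shape [N, M, H, W], e.g. [N, M, 25, 25].
    Assumption: locations that are not selected are set to 0.
    """
    n, m, h, w = img_fmap.shape
    # L2 norm over the channel dimension (dim=1): one value per (i, j).
    norms = torch.linalg.vector_norm(img_fmap, ord=2, dim=1)  # [N, H, W]
    # Flatten the spatial grid so topk ranks all H*W locations per sample.
    flat = norms.reshape(n, h * w)  # [N, H*W]
    topk_idxes = torch.topk(flat, k=k, dim=1).indices  # [N, k]
    # Start from zeros and write 1 at the k selected positions.
    mask = torch.zeros_like(flat)
    mask.scatter_(1, topk_idxes, 1.0)
    return mask.reshape(n, h, w)


if __name__ == "__main__":
    torch.manual_seed(0)
    N, M = 2, 8
    img_fmap = torch.randn(N, M, 25, 25)
    mask = topk_location_mask(img_fmap, k=10)
    print(mask.shape)            # torch.Size([2, 25, 25])
    print(mask.sum(dim=(1, 2)))  # tensor([10., 10.])
    # Keep the features only at the selected locations.
    selected = img_fmap * mask.unsqueeze(1)  # [N, M, 25, 25]
```

Because the square root is monotonic, ranking by the sum of squares (as in the original snippet) selects the same locations as ranking by the true L2-norm, so the sqrt can be skipped if only the indices are needed.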

Thanks!