I need to calculate the L2 norm of an image feature-map tensor `img_fmap` of shape [N, M, 25, 25] at each (i, j) spatial location, then select the K highest values in each 25×25 map and set them to 1. My attempt:
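```python
import torch

img_fmap = img_fmap.view(N, M, 25 * 25)    # [N, M, 25, 25] -> [N, M, 625]
vector = torch.sum(img_fmap ** 2, dim=2)   # sums over the 625 locations -> [N, M]
topk_idxes = torch.topk(vector, k=self.k)  # top-k taken over M, not over locations
```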
The result has shape [N, M], not one value per (i, j) spatial location.
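A minimal sketch, assuming the norm at each (i, j) is taken across the M channels and that "equal to 1" means a binary mask (1 at the K selected locations, 0 elsewhere). The function name `topk_location_mask` is hypothetical. The key change is reducing over `dim=1` (channels) instead of over the flattened spatial axis, then running `torch.topk` over the 625 locations of each sample and writing 1s with `scatter_`:

```python
import torch


def topk_location_mask(img_fmap: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper. Assumes the L2 norm is taken across channels (dim=1).

    Returns a [N, H, W] float mask with 1 at the k spatial locations whose
    channel-wise L2 norm is largest in each sample, and 0 elsewhere.
    """
    n, m, h, w = img_fmap.shape
    # L2 norm across the M channels at every (i, j) location -> [N, H, W]
    norms = torch.linalg.vector_norm(img_fmap, ord=2, dim=1)
    # Flatten the spatial grid so top-k runs over all H*W locations -> [N, H*W]
    flat = norms.reshape(n, h * w)
    # topk returns (values, indices); only the indices are needed here
    _, idx = torch.topk(flat, k=k, dim=1)
    # Start from all zeros and write 1 at each selected location in every row
    mask = torch.zeros_like(flat)
    mask.scatter_(1, idx, 1.0)
    return mask.view(n, h, w)
```

Because the square root is monotonic, ranking by the squared sum (as in the original code) selects the same top-K locations up to ties, so the sqrt can be skipped. Note also that `torch.topk` returns a `(values, indices)` pair rather than just indices. To apply the mask to the feature map itself, `mask.unsqueeze(1)` broadcasts it over the channel dimension.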