How to use torch.topk on a multi-dimensional tensor

I want to select features based on a score map, using two-dimensional (row, column) indices. I can do this for a single, unbatched tensor, and the batched case can be handled with a for loop, but I would like a more elegant way to do it. Here is my code:

import numpy as np
import torch

topk = 10  # number of locations to keep (the original post does not give the value)

feature_map = torch.randn(8, 256, 64, 64)  # [batch_size, C, H, W]
score_map = torch.randn(8, 1, 64, 64)      # [batch_size, 1, H, W]

batch_select_feature = []
for i in range(feature_map.size(0)):
    cur_score_map = score_map[i].detach().cpu().squeeze()  # [H, W]
    val, idx = torch.topk(cur_score_map.flatten(), topk)
    # Convert flat indices into (row, col) pairs: shape [topk, 2]
    topk_idx = np.array(np.unravel_index(idx.numpy(), cur_score_map.shape)).T
    # Gather the C-dim feature vector at each (row, col): shape [topk, C]
    select_feature = torch.cat([feature_map[i, :, x, y].unsqueeze(0) for x, y in topk_idx])
    batch_select_feature.append(select_feature.unsqueeze(0))
batch_select_feature = torch.cat(batch_select_feature)  # [batch_size, topk, C]
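As an aside, the NumPy round trip in the loop (`np.unravel_index`) can be replaced with integer arithmetic on the flat indices, which keeps everything in PyTorch (and on the GPU, if the tensors live there). A minimal sketch for a single sample, assuming row-major flattening (what `flatten` does) and an arbitrary `topk = 10`; the names `feature_map_i`, `rows` and `cols` are illustrative only:

```python
import numpy as np
import torch

# Assumed sizes for illustration: C channels, H x W spatial map, topk = 10.
C, H, W = 256, 64, 64
topk = 10

feature_map_i = torch.randn(C, H, W)  # one sample of feature_map
cur_score_map = torch.randn(H, W)     # one sample of score_map, already squeezed

val, idx = torch.topk(cur_score_map.flatten(), topk)  # idx indexes the flattened H*W map

# Row-major unravel: flat = row * W + col, so row = flat // W and col = flat % W.
rows = torch.div(idx, W, rounding_mode="floor")
cols = idx % W

# Same (row, col) pairs as the np.unravel_index step in the loop.
np_rows, np_cols = np.unravel_index(idx.numpy(), (H, W))
assert (rows.numpy() == np_rows).all() and (cols.numpy() == np_cols).all()

# Advanced indexing takes every channel at each (row, col): [C, topk] -> [topk, C]
select_feature = feature_map_i[:, rows, cols].T
```

With two adjacent index tensors after a full slice, the result has shape `[C, topk]`, so the final `.T` gives the same `[topk, C]` layout that the list comprehension builds.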

Any solution? Thanks
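One loop-free way to do this is to flatten the spatial dimensions, run `torch.topk` along that flattened dimension for the whole batch at once, and then use `torch.gather` with the indices expanded over the channel dimension. This is a sketch under the shapes from the question (`feature_map` is `[B, C, H, W]`, `score_map` is `[B, 1, H, W]`); the function name `batch_topk_features` and `topk = 10` are assumptions for illustration:

```python
import torch


def batch_topk_features(feature_map, score_map, topk):
    """Hypothetical helper: feature vectors at the top-k score locations of each sample.

    feature_map: [B, C, H, W], score_map: [B, 1, H, W]
    returns: values [B, topk], features [B, topk, C]
    """
    B, C, H, W = feature_map.shape
    # Flatten the spatial dims so topk runs over H*W for every sample at once.
    val, idx = torch.topk(score_map.flatten(1), topk, dim=1)  # [B, topk]
    feats = feature_map.flatten(2)                            # [B, C, H*W]
    # Repeat the indices across channels so gather picks the same locations in every channel.
    gather_idx = idx.unsqueeze(1).expand(-1, C, -1)           # [B, C, topk]
    selected = torch.gather(feats, 2, gather_idx)             # [B, C, topk]
    return val, selected.transpose(1, 2)                      # [B, topk, C]


if __name__ == "__main__":
    feature_map = torch.randn(8, 256, 64, 64)  # [batch_size, C, H, W]
    score_map = torch.randn(8, 1, 64, 64)      # [batch_size, 1, H, W]
    val, batch_select_feature = batch_topk_features(feature_map, score_map, topk=10)
    print(batch_select_feature.shape)  # torch.Size([8, 10, 256])
```

Because the flattened index is `row * W + col`, no explicit 2D unravelling is needed for the selection itself. An equivalent alternative is advanced indexing, `feature_map.flatten(2).transpose(1, 2)[torch.arange(B)[:, None], idx]`, which also yields `[B, topk, C]`. Unlike the loop, neither version calls `.detach().cpu()`, so gradients can flow back into `feature_map` if that is wanted.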