I want to select features from a feature map based on a score map: for each sample, take the two-dimensional (row, column) indices of the top-k scores and pick the feature vectors at those positions. I can do this for a single, non-batched tensor, and the batched case can be handled with a for loop, but I would like a more elegant way to do it. Here is my code:
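```python
import numpy as np
import torch

topk = 16  # example value; the number of positions to select is not given here
feature_map = torch.randn(8, 256, 64, 64)  # [batch_size, C, H, W]
score_map = torch.randn(8, 1, 64, 64)      # [batch_size, 1, H, W]

batch_select_feature = []
for i in range(feature_map.size(0)):
    # Score map of the i-th sample as an [H, W] tensor
    cur_score_map = score_map[i].detach().cpu().squeeze()
    # Top-k scores over the flattened H*W positions
    val, idx = torch.topk(cur_score_map.flatten(), topk)
    # Convert flat indices back to (row, col) pairs, shape [topk, 2]
    topk_idx = np.array(np.unravel_index(idx.numpy(), cur_score_map.shape)).T
    # Collect the C-dim feature vector at each selected position -> [topk, C]
    select_feature = torch.cat([feature_map[i, :, x, y].unsqueeze(0) for x, y in topk_idx])
    batch_select_feature.append(select_feature.unsqueeze(0))
batch_select_feature = torch.cat(batch_select_feature)  # [batch_size, topk, C]
```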
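A possible loop-free version, sketched under the assumption that only the `[batch_size, topk, C]` result is needed and not the explicit (row, col) pairs: run `torch.topk` over the flattened spatial dimension for the whole batch at once, then use `torch.gather` on the feature map flattened to `[B, C, H*W]`. The helper name `select_topk_features` is hypothetical.

```python
import torch


def select_topk_features(feature_map: torch.Tensor,
                         score_map: torch.Tensor,
                         topk: int) -> torch.Tensor:
    """Hypothetical helper: C-dim features at the top-k scoring positions.

    feature_map: [B, C, H, W], score_map: [B, 1, H, W] -> returns [B, topk, C]
    """
    B, C, H, W = feature_map.shape
    # Top-k over the flattened H*W positions of each sample; idx is [B, topk]
    _, idx = torch.topk(score_map.flatten(1), topk, dim=1)
    # Flatten spatial dims of the features: [B, C, H*W]
    flat_feat = feature_map.flatten(2)
    # Broadcast the same spatial indices across all C channels: [B, C, topk]
    gather_idx = idx.unsqueeze(1).expand(B, C, topk)
    gathered = torch.gather(flat_feat, 2, gather_idx)  # [B, C, topk]
    return gathered.transpose(1, 2)  # [B, topk, C]


if __name__ == "__main__":
    feature_map = torch.randn(8, 256, 64, 64)
    score_map = torch.randn(8, 1, 64, 64)
    out = select_topk_features(feature_map, score_map, 16)
    print(out.shape)  # torch.Size([8, 16, 256])
```

Because `torch.topk` returns values sorted in descending order by default, the selected positions come out in the same order as in the loop version. If the (row, col) coordinates are also needed, they can be recovered from the flat indices as `torch.div(idx, W, rounding_mode="floor")` and `idx % W`, which avoids the round trip through NumPy and the CPU.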