I want to do the feature selection based on the score map with two dimensional index. I can realize this without batch of tensor and batch also can be implemented with a for loop. But I want to know more elegant code to realize this. The following my code.
feature_map = torch.randn(8, 256, 64, 64) # [batch_size, C, H, W]
score_map = torch.randn(8, 1, 64, 64) # [batch_size, 1, H, W]
batch_select_feature = []
for i in range(feature_map.size(0)):
cur_score_map = score_map[i].detach().cpu().squeeze()
val, idx = torch.topk(cur_score_map.flatten(), topk)
topk_idx = np.array(np.unravel_index(idx.numpy(), cur_score_map.shape)).T
select_feature = torch.cat([feature_map[i, :, x, y].unsqueeze(0) for x, y in topk_idx])
batch_select_feature.append(select_feature.unsqueeze(0))
batch_select_feature = torch.cat(batch_select_feature)