Top-k extraction and reconstruction

I have a tensor ‘x’ of shape [b, c, d, h, w]. First, I compute the norm over all channels, which gives a tensor ‘channel_norm’ of shape [b, 1, d, h, w]. Then I run top-k on ‘channel_norm’ along dimension ‘d’ to get the indices of the top k elements at each (b, h, w) position. Since ‘x’ and ‘channel_norm’ differ only in ‘c’, the top-k ‘d’ indices found in ‘channel_norm’ should select the corresponding depth slices in ‘x’. However, I do not know how to do this.
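A minimal sketch of the extraction step, assuming PyTorch and that "top k" means the k largest norms. `torch.gather` picks elements along one dimension using an index tensor with the same number of dimensions as the input, so the [b, 1, k, h, w] indices only need to be expanded over ‘c’. The function name `extract_topk_depths` is hypothetical, and the infinity norm is written as `x.abs().amax(dim=1, keepdim=True)`, which gives the same values as `torch.norm(x, p=float('inf'), dim=1, keepdim=True)`.

```python
import torch


def extract_topk_depths(x: torch.Tensor, k: int):
    """Hypothetical helper: at every (b, h, w), keep the k depth slices
    whose channel norm is largest.

    x: [B, C, D, H, W] -> returns (y: [B, C, K, H, W], idx: [B, 1, K, H, W])
    """
    c = x.shape[1]
    # Infinity norm over the channel dim: [B, 1, D, H, W]
    channel_norm = x.abs().amax(dim=1, keepdim=True)
    # Depth indices of the k largest norms at each (b, h, w): [B, 1, K, H, W]
    _, idx = torch.topk(channel_norm, k=k, dim=2, largest=True, sorted=False)
    # gather needs an index with as many dims as x; expanding the size-1
    # channel dim to C is a view (no copy), so every channel uses the same depths.
    y = torch.gather(x, dim=2, index=idx.expand(-1, c, -1, -1, -1))
    return y, idx
```

Keep `idx`: it is exactly what the reconstruction step needs later.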

After selecting the top k elements of ‘x’, we get a tensor ‘y’ of shape [b, c, k, h, w]. We apply some convolutions to ‘y’ and obtain a tensor of shape [b, 1, k, h, w]. Then we want to put the k depths back at their original positions. For example, suppose the top 2 elements are at depth positions 1 and 5 and the other depths are dropped, so along ‘d’ we keep a vector of length 2. After all the calculations, we want to reconstruct a tensor of shape [b, 1, d, h, w], either by filling the dropped positions with 0 or by interpolation. I do not know how to do this either.
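For the zero-filled reconstruction, `Tensor.scatter` is the inverse of the gather: it writes each of the k slices back to the depth position stored in the index tensor and leaves every other position at its initial value, here 0. A minimal sketch, where `scatter_back` is a hypothetical helper and `z` stands for the convolution output:

```python
import torch


def scatter_back(z: torch.Tensor, idx: torch.Tensor, depth: int) -> torch.Tensor:
    """Hypothetical helper: put the k processed slices back at their depths.

    z:   [B, C', K, H, W]  output of the convolutions on the gathered tensor
    idx: [B, 1, K, H, W]   depth indices returned by torch.topk
    Returns [B, C', depth, H, W], zero at every depth that was dropped.
    """
    b, c_out, _, h, w = z.shape
    out = z.new_zeros(b, c_out, depth, h, w)
    # out[b, c, idx[b, 0, j, h, w], h, w] = z[b, c, j, h, w]
    return out.scatter(2, idx.expand(-1, c_out, -1, -1, -1), z)
```

Because topk returns distinct indices within each (b, h, w) column, no position is written twice, and gradients flow back to `z` through the scatter.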

I hope I have stated everything clearly. My code for the first part is attached below. Thank you so much!

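    import torch

    def forward(self, x):
        # x: [B, C, D, H, W]
        b, c, d, h, w = x.shape
        if self.hparams.top_depth != -1:
            # Infinity norm over the channel dim -> channel_norm: [B, 1, D, H, W]
            channel_norm = torch.norm(x, p=float('inf'), dim=1, keepdim=True)
            num_top_k = int(self.hparams.top_depth * d)
            # largest=True keeps the depths with the LARGEST norm ("top k")
            _, indices_top_k = torch.topk(channel_norm, k=num_top_k, dim=2,
                                          largest=True, sorted=False)
            # indices_top_k: [B, 1, num_top_k, H, W]
            # torch.index_select only accepts a 1-D index, so use torch.gather:
            # expand the indices over C (a view, no copy) and gather along D.
            x = torch.gather(x, dim=2, index=indices_top_k.expand(-1, c, -1, -1, -1))
            # x: [B, C, num_top_k, H, W]; keep indices_top_k to scatter back later
        return x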
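For the interpolation option, one possible approach (an assumption, not the only way) is linear interpolation along ‘d’ between neighbouring kept depths, copying the nearest kept value before the first and after the last kept index. The sketch sorts the kept indices, uses `torch.searchsorted` to find the two kept neighbours of every depth, and blends their values. `interpolate_back` is a hypothetical name; `z` and `idx` have the shapes [b, c', k, h, w] and [b, 1, k, h, w].

```python
import torch


def interpolate_back(z: torch.Tensor, idx: torch.Tensor, depth: int) -> torch.Tensor:
    """Hypothetical helper: fill all `depth` positions by linear interpolation
    along D between kept slices; edges copy the nearest kept value."""
    b, c_out, k, h, w = z.shape
    # Sort kept positions so they increase along K (topk(sorted=False) need not).
    pos, order = torch.sort(idx, dim=2)                             # [B, 1, K, H, W]
    vals = torch.gather(z, 2, order.expand(-1, c_out, -1, -1, -1))  # [B, C', K, H, W]
    # searchsorted works on the last dim, so move K to the end: [B, 1, H, W, K]
    pos_last = pos.permute(0, 1, 3, 4, 2).contiguous()
    query = torch.arange(depth, device=z.device).expand(b, 1, h, w, depth).contiguous()
    right = torch.searchsorted(pos_last, query)  # first kept position >= t, in [0, K]
    left = (right - 1).clamp(min=0)
    right = right.clamp(max=k - 1)
    p0 = torch.gather(pos_last, -1, left)
    p1 = torch.gather(pos_last, -1, right)
    # Weight of the right neighbour; 0 when both neighbours coincide (edges).
    span = (p1 - p0).clamp(min=1)
    t = (query - p0).to(z.dtype) / span.to(z.dtype)
    wgt = torch.where(p1 == p0, torch.zeros_like(t), t.clamp(0, 1))
    # Back to [B, 1, D, H, W] layout.
    left = left.permute(0, 1, 4, 2, 3)
    right = right.permute(0, 1, 4, 2, 3)
    wgt = wgt.permute(0, 1, 4, 2, 3)
    v0 = torch.gather(vals, 2, left.expand(-1, c_out, -1, -1, -1))
    v1 = torch.gather(vals, 2, right.expand(-1, c_out, -1, -1, -1))
    return v0 + wgt * (v1 - v0)                                     # [B, C', D, H, W]
```

Only the values in `z` carry gradients; the positions and weights are derived from the integer indices.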