Preserve only the top-k channels in a channel attention layer

Hi there!
I wish to implement a custom channel attention layer. Concretely, for a (B,C,H,W) tensor, the channel attention layer returns an attention tensor of shape (B,C,1,1), and the input tensor is multiplied by this vector-like tensor.
What I wish to add is selecting the k most important channels according to the attention vector, so the resulting tensor has shape (B,k,H,W).
For the implementation, the only approach I can think of is:

  1. sorting the channels by their attention values
  2. selecting the top k with an operation like [:, :k, :, :]
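To make the question concrete, here is a minimal sketch of what I have in mind, using `torch.topk` and `torch.gather` instead of sorting plus slicing (the function name `topk_channel_select` and the shapes are just my own illustration, not an existing API):

```python
import torch

def topk_channel_select(x, attn, k):
    """Keep only the k channels with the largest attention scores.

    x:    (B, C, H, W) input feature map
    attn: (B, C, 1, 1) channel attention scores
    returns a (B, k, H, W) tensor of the selected, attention-weighted channels
    """
    B, C, H, W = x.shape
    # indices of the top-k attention scores per sample, shape (B, k)
    _, idx = torch.topk(attn.view(B, C), k, dim=1)
    # expand the indices so they address full (H, W) channel slices
    idx = idx[:, :, None, None].expand(B, k, H, W)
    # multiply by the attention first, then gather the selected channels;
    # gather is differentiable w.r.t. its input, so gradients flow back
    # into both x and attn for the selected channels
    return torch.gather(x * attn, 1, idx)
```

My understanding is that the top-k *indices* themselves get no gradient, but since the output is built from `x * attn`, both tensors stay in the autograd graph for the channels that were kept.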

But my concern is whether autograd will still work, because I can't think of any differentiable function that achieves the "[:k]" operation.
Can my naive implementation work? If yes, can anyone explain how it works, e.g. how the gradient passes through? If not, can anyone give me some suggestions?
Thank you so much!!!
My idea is based on the implementation here: