Patch Selection from Visual Transformer Patches Based on Attention Map

I am currently working with vision transformers. My input has shape [B, 64, 160], where 64 is the number of patches and 160 is the embedding dimension, and I have an attention map of shape [B, 64] computed from the input patches with an nn.Linear layer. I want to select the top 32 patches from the input according to the attention map, so the output should have shape [B, 32, 160] and contain the patches with the strongest scores. I tried several available solutions, but none of them account for backpropagation through the nn.Linear attention layer. In other words, the selection process (slicing, …) is non-differentiable with respect to the indices, so PyTorch won't compute any gradient for the attention layer.

torch.select and slicing are differentiable and will provide the gradient to the selected items:

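import torch

# select
x = torch.randn(10, requires_grad=True)
out = torch.select(x, 0, 1)  # pick index 1 along dim 0
print(out.grad_fn)
# <SelectBackward0 object at 0x7fdca5774ca0>
out.backward()
print(x.grad)  # only the selected element receives a gradient
# tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])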

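# slicing
x = torch.randn(10, requires_grad=True)
out = x[1:3]
print(out.grad_fn)
# <SliceBackward0 object at 0x7fdca57183d0>
out.mean().backward()
print(x.grad)  # the two sliced elements share the gradient of the mean
# tensor([0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000])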
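
Applied to the shapes in the question, one possible sketch looks like the following. The gradient reaches the selected values, not the integer indices, so the attention layer only gets a gradient if its scores take part in the output. The sigmoid weighting of the kept patches is an assumption made for illustration, not something stated in the thread, and attn_layer is a hypothetical stand-in for the nn.Linear scoring layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, N, D, K = 4, 64, 160, 32  # batch, patches, embedding dim, patches to keep

patches = torch.randn(B, N, D)
attn_layer = nn.Linear(D, 1)  # hypothetical scoring layer: one score per patch

scores = attn_layer(patches).squeeze(-1)            # [B, N]
top_scores, top_idx = torch.topk(scores, K, dim=1)  # both [B, K]

# Gather the K highest-scoring patches per sample.
gather_idx = top_idx.unsqueeze(-1).expand(-1, -1, D)  # [B, K, D]
selected = torch.gather(patches, 1, gather_idx)       # [B, K, D]

# The indices are integers, so `selected` alone carries no gradient
# back to attn_layer (patches itself does not require grad here).
print(selected.requires_grad)

# Assumption: scale each kept patch by a sigmoid of its score so that
# the output depends on the attention layer's values.
weighted = selected * torch.sigmoid(top_scores).unsqueeze(-1)

weighted.sum().backward()
print(weighted.shape)
print(attn_layer.weight.grad is not None)
```

With this weighting the attention layer is trained only through the scores of the patches that were kept; the choice of which patches to keep remains a hard, non-differentiable decision.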