I am currently working with vision transformers. My input has shape [B, 64, 160], where 64 is the number of patches and 160 is the embedding dimension, and I have an attention map of size [B, 64] computed with an nn.Linear layer from the input patches. I want to select the top 32 patches from the input with respect to the attention map (the output should have shape [B, 32, 160] and contain the patches with the strongest scores). I tried several available solutions, but none of them account for backpropagation through the nn.Linear attention layer; in other words, the selection process (slicing, torch.select, …) is non-differentiable with respect to the indices, so PyTorch won't compute any gradient for the attention layer.
torch.select and slicing are in fact differentiable operations: they propagate gradients to the selected elements of the input tensor (the indices themselves receive no gradient, but the selected values do):
# select
x = torch.randn(10, requires_grad=True)
out = torch.select(x, 0, 1)
print(out.grad_fn)
# <SelectBackward0 object at 0x7fdca5774ca0>
out.backward()
print(x.grad)
# tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])

# slicing
x = torch.randn(10, requires_grad=True)
out = x[1:3]
print(out.grad_fn)
# <SliceBackward0 object at 0x7fdca57183d0>
out.mean().backward()
print(x.grad)
# tensor([0.0000, 0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000])
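That said, gradients only flow to the *selected patches*, not to the scores that produced the indices (torch.topk indices are integers and carry no gradient). A common workaround is to weight the selected patches by their attention scores, so the linear layer still receives gradients through the multiplication. Here is a minimal sketch of that idea for the shapes in the question; the softmax weighting is an assumption on my part, not something prescribed by PyTorch:

```python
import torch
import torch.nn as nn

B, N, D, K = 2, 64, 160, 32
x = torch.randn(B, N, D)
attn_layer = nn.Linear(D, 1)  # the attention-scoring layer from the question

scores = attn_layer(x).squeeze(-1)             # [B, 64]
topk_scores, topk_idx = scores.topk(K, dim=1)  # values are differentiable, indices are not

# gather the top-K patches: [B, 32, 160]
selected = x.gather(1, topk_idx.unsqueeze(-1).expand(-1, -1, D))

# weight the selected patches by their (softmaxed) scores so that
# attn_layer still receives gradients through this multiplication
weights = topk_scores.softmax(dim=1).unsqueeze(-1)  # [B, 32, 1]
out = selected * weights                            # [B, 32, 160]

out.mean().backward()
print(out.shape)                          # torch.Size([2, 32, 160])
print(attn_layer.weight.grad is not None) # True: gradient reaches the attention layer
```

Note this only gives the attention layer a gradient signal for the patches that made the cut; if you need a gradient for every score, you would have to look at soft/relaxed top-k schemes instead.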