Differentiable top k masking

I would like to mask an input based on the top-k mask values, naively doing something like the following code. Since this is not differentiable, I wanted to ask if there is a differentiable workaround to achieve the same thing. Thanks!

import torch
top = 2

inp = torch.rand(5, 5, requires_grad=True)
mask = torch.rand(5, 5, requires_grad=True)

# take the indices of the top-k mask values along dim 0 and gather
# the corresponding entries of the input
top_mask_indices = torch.topk(mask, top, dim=0).indices
masked_inp = torch.gather(inp, 0, top_mask_indices)

loss = torch.mean((masked_inp - torch.rand(top, 5)) ** 2)
loss.backward()

print(mask.grad)   # None: the integer indices cut mask out of the graph
print(inp.grad)    # populated, since gather is differentiable w.r.t. inp

Hi Damaggu!

When you mask with your top-k indices, you keep 100% of each term the mask
includes and 0% of each term it excludes, so the selection is an all-or-nothing step.
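
To see this concretely, here is a quick check: a tiny perturbation of mask almost never changes which indices land in the top k, so the gathered result is piecewise constant as a function of mask and its gradient with respect to mask is zero almost everywhere (and undefined at the jumps).

import torch

mask = torch.rand(5, 5)
idx_before = torch.topk(mask, 2, dim=0).indices
# nudge the mask slightly; the selected indices almost surely stay the same
idx_after = torch.topk(mask + 1e-6 * torch.randn(5, 5), 2, dim=0).indices

print(torch.equal(idx_before, idx_after))   # True with overwhelming probability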

The approach that comes to mind is to convert the elements of your mask
tensor to weight-like or probability-like values, and use them to weight the
terms in your loss function. The largest values in mask (that would be in,
say, the top k) would weight terms in your loss function the most heavily,
but values that didn’t make it into the top k would still contribute, and you
would avoid the non-differentiable discreteness of in-vs.-out of the top k.

Something like:

import torch

inp = torch.rand(5, 5, requires_grad=True)
weighted_mask = torch.rand(5, 5, requires_grad=True)   # probability-like
target = torch.rand(5, 5)

# weight each squared-error term by its mask value, then normalize
weighted_loss = torch.sum(weighted_mask * (inp - target)**2) / weighted_mask.sum()
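
For what it's worth, here is a minimal end-to-end sketch of that idea. The softmax over dim 0 is just one possible way to make the raw mask values probability-like (a sigmoid would do as well); it is an added assumption, not something the weighting approach requires. The point is that the mask now receives a gradient, unlike with the hard top-k gather:

import torch

inp = torch.rand(5, 5, requires_grad=True)
mask_logits = torch.rand(5, 5, requires_grad=True)
target = torch.rand(5, 5)

# softmax along dim 0 turns the raw mask values into positive weights
# that sum to 1 in each column (probability-like)
weights = torch.softmax(mask_logits, dim=0)

weighted_loss = torch.sum(weights * (inp - target) ** 2) / weights.sum()
weighted_loss.backward()

print(mask_logits.grad)   # populated: gradients flow through the soft weights
print(inp.grad)

If a hard top-k selection is still needed at inference time, one option is to train with the soft weights and switch to the hard torch.topk selection only afterwards, accepting a small train/inference mismatch.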

Best.

K. Frank

Hello @damaggu, did you find a workaround?