Is torch.masked_select differentiable?

The question says it all. Thank you in advance.

Hi,

If you don’t get an error when you backprop, then it is :slight_smile:

A quick check shows that it is:

import torch

a = torch.rand(10, 10, requires_grad=True)
mask = torch.randn(10, 10).clamp(min=0).bool() # Random boolean mask (roughly half True)

out = torch.masked_select(a, mask)

out.sum().backward()
print(a.grad) # A Tensor of 1s where the mask is True, 0s elsewhere

Hi, what about torch.sort()? It does not give an error, but the sort operation should be non-differentiable.

It depends how you look at it :smiley:
It is differentiable if you consider that it only shuffles the elements around: the gradient for each output propagates back to exactly the input element it came from.
There are also research papers that discuss including the "getting the indices for shuffling" part in the gradient computation (I don’t have a link on hand, but you can find them online by searching for differentiable sorting).
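A quick sketch (not from the thread, just an illustration) of the "shuffling" view: if you weight each sorted output differently, you can see each gradient routed back to the input position it originally came from.

```python
import torch

# torch.sort's backward treats the sort as a permutation: the gradient
# of each sorted output is scattered back to the input it came from.
x = torch.tensor([3.0, 1.0, 2.0], requires_grad=True)
values, indices = torch.sort(x)

# indices is [1, 2, 0]: x[1] is smallest, x[2] next, x[0] largest.
# Weight the sorted outputs so each position gets a distinct gradient.
(values * torch.tensor([10.0, 20.0, 30.0])).sum().backward()

# The gradients [10, 20, 30] are permuted back to the original positions:
print(x.grad)  # tensor([30., 10., 20.])
```

Note that the gradient only flows through the values, not the indices, which is why the "hard" sort gives no learning signal about the ordering itself; that is what the differentiable-sorting papers address.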

Thank you so much! I will search for the related papers you mentioned.