Is torch.masked_select differentiable?

The question says it all. Thank you in advance.

Hi,

If you don’t get an error when you backprop through it, then it is differentiable :)
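"No error on backward" tells you autograd has a backward for the op. For a stronger check, a minimal sketch using `torch.autograd.gradcheck` compares the analytical gradient against finite differences. The small shapes, the random mask, and forcing one entry to `True` (so the output isn't empty) are my own choices for illustration:

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)

# gradcheck uses finite differences, so the differentiable input should be float64.
a = torch.rand(4, 5, dtype=torch.double, requires_grad=True)

# Random boolean mask; it is not differentiated. Force at least one True
# so masked_select never returns an empty tensor.
mask = torch.rand(4, 5) > 0.5
mask[0, 0] = True

def f(x):
    # Only `x` is treated as a differentiable input; `mask` is captured.
    return torch.masked_select(x, mask)

ok = gradcheck(f, (a,))  # returns True, or raises if the gradients disagree
print(ok)
```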

A quick check shows that it is:

```python
import torch

a = torch.rand(10, 10, requires_grad=True)
mask = torch.randn(10, 10).clamp(min=0).bool()  # Random mask: True wherever randn was positive

out = torch.masked_select(a, mask)  # 1-D tensor of the selected elements

out.sum().backward()
print(a.grad)  # 1 where mask is True, 0 elsewhere
```
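The 0s and 1s above come from `sum()`, which sends a gradient of 1 to every selected element. With a non-uniform upstream gradient you can see where each value lands: each selected position receives its own upstream value (in row-major order) and unselected positions receive 0. A minimal sketch with a hand-picked mask (chosen only for readability), comparing against `Tensor.masked_scatter`, which places values into `True` positions in that same order:

```python
import torch

a = torch.rand(3, 4, requires_grad=True)
mask = torch.tensor([[True, False, False, True],
                     [False, True, False, False],
                     [False, False, True, True]])

out = torch.masked_select(a, mask)  # 1-D, one entry per True in mask (row-major order)
print(out.shape)  # torch.Size([5])

# Distinct upstream gradients so we can see where each one ends up.
upstream = torch.arange(1.0, out.numel() + 1)  # tensor([1., 2., 3., 4., 5.])
out.backward(upstream)

# Upstream values scattered back into the True positions, zeros elsewhere.
expected = torch.zeros(3, 4).masked_scatter(mask, upstream)
print(a.grad)
print(torch.equal(a.grad, expected))  # True
```

So the backward of `masked_select` is effectively a scatter of the incoming gradient back into the masked positions of the input's shape.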