Hi Luc!
The short story is to use sigmoid() to create a “soft” mask.
The problem is that the threshold-masking operation is not (usefully) differentiable. As you vary threshold over some range, the same masked elements of your tensor are kept, so loss is constant over this range of threshold. Mathematically, over this range, the gradient of loss is zero, which isn’t useful for gradient descent (and pytorch is smart enough not to compute this not-useful gradient).
At the specific (discrete) values of threshold where the set of masked elements changes, the gradient is mathematically undefined (or inf, if you prefer). This is also not useful (and, in any event, such values form a set of measure zero).
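For instance, here is a minimal check (a sketch using the same seed and setup as the session below) showing that pytorch leaves threshold.grad empty when you backpropagate through the hard mask:
>>> import torch
>>> _ = torch.manual_seed (2024)
>>> X = torch.randn (10, 3, requires_grad = True)
>>> threshold = torch.tensor (1.5, requires_grad = True)
>>> loss = X[X.norm (dim = 1) < threshold].norm (dim = 1).mean()
>>> loss.backward()   # gradients flow back to X, but not through the boolean comparison
>>> print (threshold.grad)
None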
Instead, you want to smoothly turn the masked elements on and off.
Consider:
>>> import torch
>>> print (torch.__version__)
2.3.1
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> X = torch.randn (10, 3, requires_grad = True)
>>> threshold = torch.tensor (1.5, requires_grad = True)
>>>
>>> loss = X[X.norm (dim = 1) < threshold].norm (dim = 1).mean()
>>> loss
tensor(1.0206, grad_fn=<MeanBackward0>)
>>>
>>> hard = 10.0 # larger values make the mask "harder"
>>> X_norm = X.norm (dim = 1)
>>> soft_mask = torch.sigmoid (hard * (threshold - X_norm))
>>> soft_mask
tensor([9.0025e-03, 9.9987e-01, 8.5999e-01, 9.0971e-07, 9.6722e-01, 9.2340e-08,
9.6193e-03, 9.9644e-01, 9.9148e-01, 9.8524e-01],
grad_fn=<SigmoidBackward0>)
>>>
>>> lossB = (soft_mask * X_norm).sum() / soft_mask.sum() # weighted mean
>>> lossB
tensor(1.0156, grad_fn=<DivBackward0>)
>>>
>>> lossB.backward()
>>> threshold.grad
tensor(0.1020)
Here we smoothly turn the masked elements on and off by multiplying them with soft_mask, whose values are (usually) close to zero or one.
As the parameter hard is increased, in principle to infinity, the elements of soft_mask become zero and one, soft_mask becomes effectively the same as your “hard” boolean mask, and lossB becomes equal to loss.
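As a rough illustration, here is a sketch (reusing X_norm and threshold from the session above, with a few arbitrary hardness values) of how the soft, weighted mean approaches loss as hard grows:
with torch.no_grad():                             # no gradients needed for this comparison
    for h in (1.0, 10.0, 100.0, 1000.0):
        m = torch.sigmoid (h * (threshold - X_norm))
        soft_loss = (m * X_norm).sum() / m.sum()  # the soft, weighted mean
        print (h, soft_loss.item())               # approaches loss (1.0206) as h increases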
However, the larger you make hard, the smaller the range of threshold becomes over which the gradient of lossB with respect to threshold differs from zero enough to be of practical use. (And even when the gradient is mathematically non-zero, it can underflow to zero.)
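You can see this with a sketch along these lines (again reusing X_norm from above; the hardness values are arbitrary), which rebuilds the soft loss with a fresh threshold for each hardness and prints its gradient:
xn = X_norm.detach()                              # only track gradients through the threshold
for h in (1.0, 10.0, 100.0, 1000.0):
    thr = torch.tensor (1.5, requires_grad = True)
    m = torch.sigmoid (h * (thr - xn))
    soft_loss = (m * xn).sum() / m.sum()
    soft_loss.backward()
    print (h, thr.grad.item())                    # for large h this becomes vanishingly small (and can underflow to zero)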
Think of soft_mask as a differentiable proxy for (approximation to) your hard boolean mask.
It’s up to your use case how hard or soft you want soft_mask to be and what value you should use for the parameter hard.
Best.
K. Frank