Hi Learner!
The thresholding operation is not (usefully) differentiable with respect to x: as a function of x it is piecewise constant, so its gradient is zero wherever it is defined. To train x, you should use a “soft,” differentiable thresholding operation. You may use sigmoid() as a “soft,” differentiable step function. Thus:
thresholded_vals = data_array * torch.sigmoid(data_array - x)
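A minimal sketch of the difference, assuming small made-up values for data_array and x (they are illustrative, not from the original question). The hard threshold built from a comparison produces a tensor that autograd does not connect to x, while the sigmoid version lets backward() deliver a gradient to x.

```python
import torch

# Made-up example values (illustrative only).
data_array = torch.tensor([0.5, 1.5, 2.5, 3.5])
x = torch.tensor(2.0, requires_grad=True)  # the threshold we want to train

# Hard threshold: the comparison returns a boolean tensor that autograd
# does not track, so x drops out of the graph entirely.
hard_vals = data_array * (data_array > x).float()
print(hard_vals.requires_grad)  # False: no gradient can reach x

# Soft threshold: sigmoid is smooth, so the result depends differentiably on x.
soft_vals = data_array * torch.sigmoid(data_array - x)
soft_vals.sum().backward()
print(x.grad)  # a nonzero gradient with respect to x
```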
You may introduce a parameter, alpha, to sharpen or smooth such a “soft” step function:
thresholded_vals = data_array * torch.sigmoid(alpha * (data_array - x))
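To make “train x” concrete, here is a sketch under clearly invented assumptions: the data (torch.linspace(0.0, 10.0, 101)), the objective (match the sum of values hard-thresholded at a hypothetical threshold of 7.05), the initial guess, alpha = 10.0, and the Adam settings are all chosen only for illustration. Gradient descent on the soft-thresholded sum then moves x toward the threshold that produced the target.

```python
import torch

# Hypothetical setup: 101 evenly spaced values and a target computed by
# hard-thresholding at 7.05. We try to recover that threshold by training x.
data_array = torch.linspace(0.0, 10.0, 101)
target = (data_array * (data_array > 7.05).float()).sum()

x = torch.tensor(5.0, requires_grad=True)  # initial guess for the threshold
alpha = 10.0                                # sharpness of the soft step (assumed value)
optimizer = torch.optim.Adam([x], lr=0.05)

for step in range(500):
    optimizer.zero_grad()
    # Soft, differentiable thresholding with the sharpness parameter alpha.
    thresholded_vals = data_array * torch.sigmoid(alpha * (data_array - x))
    loss = (thresholded_vals.sum() - target) ** 2  # hypothetical objective
    loss.backward()
    optimizer.step()

print(f"learned threshold x = {x.item():.3f}")  # expected to end up near 7.05
```

Because the soft step is not exactly the hard step, the learned x need not equal 7.05 exactly; with alpha = 10.0 and this grid spacing the discrepancy is small.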
As you increase alpha towards infinity, the thresholding sharpens into a hard step function.
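The effect of alpha can be checked numerically. This sketch, again using made-up values for data_array and x, measures the largest gap between the soft and hard thresholded values for a few hypothetical choices of alpha; the gap shrinks as alpha grows.

```python
import torch

# Made-up example values (illustrative only).
data_array = torch.tensor([0.5, 1.5, 2.5, 3.5])
x = torch.tensor(2.0)

# Hard threshold: keep values above x, zero out the rest.
hard_vals = data_array * (data_array > x).float()

max_diffs = []
for alpha in (1.0, 10.0, 100.0):  # hypothetical sharpness values
    soft_vals = data_array * torch.sigmoid(alpha * (data_array - x))
    max_diff = (soft_vals - hard_vals).abs().max().item()
    max_diffs.append(max_diff)
    print(f"alpha = {alpha:6.1f}   max |soft - hard| = {max_diff:.2e}")
```

The flip side, worth keeping in mind when choosing alpha: for very large alpha the sigmoid saturates, so the gradient with respect to x becomes vanishingly small for every element not very close to x, which is the same problem the soft step was meant to avoid.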
Best.
K. Frank