Training with threshold in PyTorch

Hi Learner!

The thresholding operation is not (usefully) differentiable with respect
to x: its gradient is zero almost everywhere. To train x you should use
a “soft,” differentiable thresholding operation. You may use sigmoid()
as a “soft,” differentiable step function. Thus:

thresholded_vals = data_array * torch.sigmoid (data_array - x)
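
As a minimal sketch of how this lets gradients reach x: the data values, the starting value of x, and the sum used as a loss are all made up for illustration and are not from your use case.

```python
import torch

# Hypothetical example data and starting threshold (assumptions for illustration).
data_array = torch.tensor([0.5, 1.5, 2.5, 3.5])
x = torch.tensor(2.0, requires_grad=True)

# Soft threshold: values well above x pass nearly unchanged,
# values well below x are pushed toward zero.
thresholded_vals = data_array * torch.sigmoid(data_array - x)

# Any scalar loss will do for the illustration; here, just the sum.
loss = thresholded_vals.sum()
loss.backward()

# A hard threshold such as data_array * (data_array > x) would give x
# no usable gradient; the soft version gives a nonzero one.
print(x.grad)
```

Here x.grad comes out negative, because raising the threshold lowers the sum of the values that pass through.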

You may introduce a parameter to sharpen or smooth such a “soft”
step function:

thresholded_vals = data_array * torch.sigmoid (alpha * (data_array - x))

As you increase alpha towards infinity, the thresholding sharpens into
a hard step function.
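
A small sketch of this sharpening, comparing the soft threshold against a hard one for a few values of alpha. The helper name soft_threshold, the data values, and the choices of alpha are assumptions for illustration only.

```python
import torch

def soft_threshold(data_array, x, alpha=1.0):
    # Hypothetical helper: alpha controls how sharp the sigmoid step is.
    return data_array * torch.sigmoid(alpha * (data_array - x))

# Made-up example data and threshold.
data_array = torch.tensor([0.5, 1.5, 2.5, 3.5])
x = torch.tensor(2.0)

# Hard threshold for comparison (not differentiable in x).
hard = data_array * (data_array > x)

gaps = []
for alpha in (1.0, 10.0, 100.0):
    soft = soft_threshold(data_array, x, alpha)
    # Largest elementwise difference between the soft and hard results.
    gaps.append((soft - hard).abs().max().item())
    print(alpha, gaps[-1])
```

The printed gap shrinks as alpha grows. Note that a very large alpha also makes the gradient with respect to x nearly vanish except for data values close to x, so you will probably want a moderate alpha while training.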

Best.

K. Frank
