Training with threshold in PyTorch

I have a neural network that produces a single value for each input. I need to use this value as a threshold on another array, and the result of this thresholding operation is used to compute a loss function (the value of the threshold is not known beforehand and needs to be learned during training).
Following is an MWE:

import torch

x = torch.randn(10, 1, requires_grad=True)  # Say this is the output of the network (10 is my batch size); a real network output requires grad
data_array = torch.randn(10, 2)  # This is the data I need to threshold
ground_truth = torch.randn(10, 2)  # This is the ground truth
mse_loss = torch.nn.MSELoss()  # Loss function

# Threshold
thresholded_vals = data_array * (data_array >= x)  # Returns zero in all places where the value is less than the threshold, the value itself otherwise

# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()  # Throws a RuntimeError here

Since the thresholding operation (the comparison data_array >= x) returns a boolean tensor with no gradient information, x is cut off from the autograd graph, and backward() throws an error.
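A minimal sketch of why this happens, assuming (as in the MWE) that x stands in for the network output and therefore has requires_grad=True: the comparison produces a boolean tensor with no autograd history, so nothing in the loss connects back to x.

```python
import torch

x = torch.randn(10, 1, requires_grad=True)  # stands in for the network output
data_array = torch.randn(10, 2)
ground_truth = torch.randn(10, 2)

mask = data_array >= x  # comparison op: boolean result, no autograd history
print(mask.dtype, mask.requires_grad, mask.grad_fn)  # torch.bool False None

thresholded_vals = data_array * mask  # x never enters the autograd graph
print(thresholded_vals.requires_grad)  # False

loss = torch.nn.MSELoss()(thresholded_vals, ground_truth)
backward_failed = False
try:
    loss.backward()
except RuntimeError as err:
    backward_failed = True
    print("backward() failed:", err)
```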

How does one train a network in such a case?

Hi @learner47, I don’t think this thresholding can be differentiated with respect to x as written. Maybe try modifying the way you compute the threshold. Even then, your gradient (with respect to data_array) is going to be 1 where data_array >= x and 0 elsewhere. This discussion might help clarify things: How to make the parameter of torch.nn.Threshold learnable?
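To make the point about gradients concrete, here is a small sketch with hand-picked (hypothetical) values in which data_array also requires grad: autograd passes a gradient of 1 to entries where data_array >= x and 0 elsewhere, while x receives no gradient at all.

```python
import torch

# Hypothetical values: one threshold per row, broadcast across the two columns
x = torch.tensor([[0.0], [0.5]], requires_grad=True)
data_array = torch.tensor([[-1.0, 1.0], [0.2, 0.7]], requires_grad=True)

mask = data_array >= x  # [[False, True], [False, True]]
thresholded_vals = data_array * mask
thresholded_vals.sum().backward()

print(data_array.grad)  # 1 where data_array >= x, 0 elsewhere
print(x.grad)  # None: the comparison does not propagate gradients to x
```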


Hi Learner!

The thresholding operation is not (usefully) differentiable with respect
to x. To train x you should use a “soft,” differentiable thresholding
operation. You may use sigmoid() as a “soft,” differentiable step
function. Thus:

thresholded_vals = data_array * torch.sigmoid (data_array - x)
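Putting this soft threshold into a training loop, a minimal sketch: the tiny torch.nn.Linear network, its random input features, and the synthetic target (a hard threshold at 0.5) are hypothetical choices for illustration, not part of the original question. Because sigmoid is differentiable, loss.backward() now reaches the network's weights and bias.

```python
import torch

torch.manual_seed(0)

# Hypothetical setup: a tiny network maps 3 input features to one threshold per sample
net = torch.nn.Linear(3, 1)
features = torch.randn(256, 3)                   # hypothetical network input
data_array = torch.randn(256, 2)                 # data to be thresholded
ground_truth = data_array * (data_array >= 0.5)  # synthetic target: hard threshold at 0.5

mse_loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.05)

losses = []
for step in range(300):
    optimizer.zero_grad()
    x = net(features)  # shape (256, 1), broadcasts over the 2 columns
    # sigmoid(data_array - x) is a smooth stand-in for the step (data_array >= x)
    thresholded_vals = data_array * torch.sigmoid(data_array - x)
    loss = mse_loss(thresholded_vals, ground_truth)
    loss.backward()  # gradients now flow through x into net's parameters
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
print("mean learned threshold:", net(features).mean().item())
```

The same loop applies unchanged to a real network: only the line that computes x differs.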

You may introduce a parameter to sharpen or smooth such a “soft”
step function:

thresholded_vals = data_array * torch.sigmoid (alpha * (data_array - x))

As you increase alpha towards infinity, the thresholding sharpens into
a hard step function.
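A quick way to see the sharpening, as a sketch on an evenly spaced grid of hypothetical data values with the threshold fixed at 0: the largest gap between the soft and hard versions shrinks roughly in proportion to 1/alpha.

```python
import torch

# Hypothetical data: evenly spaced values in [-1, 1], threshold fixed at 0
data_array = torch.linspace(-1.0, 1.0, 201)
x = torch.tensor(0.0)

hard = data_array * (data_array >= x)  # exact step
max_gaps = []
for alpha in (1.0, 10.0, 100.0):
    soft = data_array * torch.sigmoid(alpha * (data_array - x))
    gap = (soft - hard).abs().max().item()
    max_gaps.append(gap)
    print(f"alpha={alpha:>5}: max |soft - hard| = {gap:.4f}")
```

The trade-off is that the derivative of sigmoid(alpha * (data_array - x)) with respect to x is concentrated in a band of width on the order of 1/alpha around x, so with a very large alpha almost no samples contribute gradient to the threshold. A moderate alpha, possibly increased gradually during training, is one common compromise.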


K. Frank
