Training with threshold in PyTorch

I have a neural network that produces a single value for each input. I need to use this value to threshold another array, and the result of the thresholding is used to compute a loss (the threshold value is not known beforehand and has to be learned during training).
Following is an MWE:

import torch

x = torch.randn(10, 1, requires_grad=True)  # Say this is the output of the network (10 is my batch size); requires_grad=True mimics that
data_array = torch.randn(10, 2)  # This is the data I need to threshold
ground_truth = torch.randn(10, 2)  # This is the ground truth
mse_loss = torch.nn.MSELoss()  # Loss function

# Threshold
thresholded_vals = data_array * (data_array >= x)  # Returns zero in all places where the value is less than the threshold, the value itself otherwise

# Compute loss and gradients
loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()  # Throws error here

Since the thresholding operation returns a tensor with no gradient history (the >= comparison is not differentiable), the backward() call throws an error.
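The failure can be reproduced in a few lines. This is a minimal sketch that assumes x stands in for a network output, so it is given `requires_grad=True` purely for illustration. The comparison `>=` yields a `torch.bool` tensor with no `grad_fn`, so nothing links `thresholded_vals` back to x:

```python
import torch

# Stand-in for the network output; requires_grad=True is an assumption made
# so that x behaves like a tensor inside the autograd graph.
x = torch.randn(10, 1, requires_grad=True)
data_array = torch.randn(10, 2)
ground_truth = torch.randn(10, 2)

mask = data_array >= x                  # comparison produces a bool tensor
thresholded_vals = data_array * mask    # hard threshold

print(mask.dtype, mask.requires_grad)   # torch.bool False
print(thresholded_vals.grad_fn)         # None: nothing connects the result to x

loss = torch.nn.functional.mse_loss(thresholded_vals, ground_truth)
try:
    loss.backward()
    backward_failed = False
except RuntimeError as err:
    # The loss has no grad_fn, so autograd has nothing to differentiate
    backward_failed = True
    print("backward() failed:", err)
```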

How does one train a network in such a case?

Hi @learner47, I don’t think you can differentiate with respect to x here. Maybe try modifying the way you compute the threshold. Even then, your gradient is going to be 1 where data_array >= x and 0 elsewhere. This discussion might help clarify things: How to make the parameter of torch.nn.Threshold learnable?
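Reading "your gradient" as the gradient with respect to data_array, a quick check looks like this. It is a sketch that assumes, only for illustration, that both tensors require grad: the gradient reaching data_array is exactly the 0/1 mask, while x receives no gradient at all.

```python
import torch

torch.manual_seed(0)
x = torch.randn(10, 1, requires_grad=True)            # threshold (stand-in for network output)
data_array = torch.randn(10, 2, requires_grad=True)   # assumed to require grad for this check

mask = data_array >= x
thresholded_vals = data_array * mask
thresholded_vals.sum().backward()

# d(sum)/d(data_array) is 1 where data_array >= x and 0 elsewhere
print(torch.equal(data_array.grad, mask.to(data_array.dtype)))  # True
# x is never reached by backward(): the comparison cut it out of the graph
print(x.grad)  # None
```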

Best!

Hi Learner!

The thresholding operation is not (usefully) differentiable with respect
to x. To train x you should use a “soft,” differentiable thresholding
operation. You may use sigmoid() as a “soft,” differentiable step
function. Thus:

thresholded_vals = data_array * torch.sigmoid (data_array - x)
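Dropped into the original MWE, this line lets backward() succeed. A minimal sketch, again assuming x is given `requires_grad=True` as a stand-in for the network output:

```python
import torch

torch.manual_seed(0)

# Stand-in for the network output (requires_grad=True is an assumption)
x = torch.randn(10, 1, requires_grad=True)
data_array = torch.randn(10, 2)
ground_truth = torch.randn(10, 2)
mse_loss = torch.nn.MSELoss()

# Soft threshold: sigmoid(data_array - x) is a smooth step between 0 and 1,
# near 1 where data_array is well above x and near 0 where it is well below.
thresholded_vals = data_array * torch.sigmoid(data_array - x)

loss = mse_loss(thresholded_vals, ground_truth)
loss.backward()  # no error: gradients now flow back to x

print(x.grad.shape)  # torch.Size([10, 1])
```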

You may introduce a parameter to sharpen or smooth such a “soft”
step function:

thresholded_vals = data_array * torch.sigmoid (alpha * (data_array - x))

As you increase alpha towards infinity, the thresholding sharpens into
a hard step function.
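To see a threshold actually being learned, here is a toy training sketch under made-up assumptions: the ground truth is generated by hard-thresholding at a hypothetical `true_threshold = 0.5`, a single `torch.nn.Parameter` replaces the network output, and `alpha = 10` and the Adam settings are illustrative choices rather than tuned values. With this setup the learned threshold should land near 0.5, though not exactly on it, since the soft step never matches the hard one perfectly.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy problem: the ground truth is data_array hard-thresholded
# at true_threshold, and we try to recover that threshold by training.
true_threshold = 0.5
alpha = 10.0  # sharpness of the soft step (illustrative choice)
data_array = torch.randn(512, 2)
ground_truth = data_array * (data_array >= true_threshold)

# A single learnable scalar stands in for the network output x.
x = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([x], lr=0.05)
mse_loss = torch.nn.MSELoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    # Soft, differentiable threshold; larger alpha gives a sharper step
    thresholded_vals = data_array * torch.sigmoid(alpha * (data_array - x))
    loss = mse_loss(thresholded_vals, ground_truth)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"learned threshold: {x.item():.3f}")
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

In a real model, x would come out of the network and the optimizer would hold the network's parameters instead. A larger alpha makes the step sharper, but it also shrinks the gradient for values far from x, so a very large alpha can make training stall.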

Best.

K. Frank
