The parameters of a model with a custom loss function are not updated during training over epochs

Hi Arul!

I am aware that this “straight-through estimator” trick is recommended
from time to time as a work-around for zero gradients (or a broken
computation graph) in the presence of thresholding.

Let me expand a little on what I see as its problems:

First, it is easy to implement and use “soft” thresholds for which you get
useful (and correct) gradients. So, what is the benefit of using something
that has demonstrable problems? (The fact that you might want to look at
a performance metric that involves a hard threshold is not a good argument.
You can look at multiple performance metrics and still choose to train with
a differentiable loss function. For example, in binary classification we often
train with BCEWithLogitsLoss but also use a prediction accuracy based
on hard thresholds as a performance metric.)
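
For concreteness, here is a minimal sketch of that pattern (the toy model and
random data are made up purely for illustration): train with BCEWithLogitsLoss,
but also report a hard-threshold accuracy as a separate performance metric:

import torch

_ = torch.manual_seed (2022)

model = torch.nn.Linear (4, 1)                # toy binary classifier (made up for illustration)
features = torch.randn (16, 4)
labels = torch.randint (2, (16, 1)).float()

loss_fn = torch.nn.BCEWithLogitsLoss()        # differentiable loss used for training
opt = torch.optim.SGD (model.parameters(), lr = 0.1)

for epoch in range (5):
    opt.zero_grad()
    logits = model (features)
    loss = loss_fn (logits, labels)           # train with the differentiable loss ...
    loss.backward()
    opt.step()
    with torch.no_grad():                     # ... but also report a hard-threshold accuracy as a metric
        accuracy = ((torch.sigmoid (logits) > 0.5).float() == labels).float().mean()
    print ('epoch: %d   loss: %.4f   accuracy: %.4f' % (epoch, loss.item(), accuracy.item()))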

Second, the “straight-through estimator” trick gives incorrect gradients.
The gradient might still be “good enough,” and your network might still
train, but why introduce this inconsistency, when a fully-consistent
alternative exists? (I’m a big fan of using “good-enough” approximate
or surrogate gradients when there is good reason, such as when correct
analytical or numerical gradients are impractical to obtain, but I don’t see
a good reason in the case under discussion.)
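
To make the inconsistency concrete, here is a tiny, self-contained illustration
(the values are made up): the "straight-through estimator" passes the upstream
gradient through unchanged -- the gradient of the identity -- even though the
true derivative of a hard threshold is zero almost everywhere:

import torch

p = torch.tensor ([0.10, 0.49, 0.51, 0.90], requires_grad = True)
hard = torch.where (p < 0.5, 0., 1.)          # hard threshold at 0.5
ste = (hard - p).detach() + p                 # "straight-through estimator" trick
ste.sum().backward()
print (p.grad)   # tensor([1., 1., 1., 1.]) -- the gradient of the identity, not of the step function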

To underscore the trouble you can get into, consider thresholding a
probability. Testing against 0.5 is a typical default choice, but perhaps
it makes more sense for your use case to threshold against, for example,
0.75. The “straight-through estimator” is fully blind to such a choice of
threshold. If you train once with an accuracy or intersection-over-union
calculated with a hard threshold of 0.5 as your loss function, and then
train again with the distinctly different loss function obtained by setting
the threshold to 0.75, using the “straight-through estimator” will yield
the exact same training and final weights even though you trained with
two different loss functions. This can hardly be considered good (even
if you deem it “good enough”).

The following script (set in the context of thresholding a probability-like
quantity) illustrates these failures of using the “straight-through estimator,”
as well as how the straightforward use of soft thresholds avoids them:

import torch
print (torch.__version__)

_ = torch.manual_seed (2022)

def logit_function (p):   # convert probability to logit (inverse of torch.sigmoid())
    return  (p / (1 - p)).log()

def soft_thresh (p, thresh, alpha = 2.5):    # soft threshold of probability p against thresh
    logits = logit_function (p)              # convert p to logit-space
    thresh_logit = logit_function (thresh)   # convert thresh to logit-space
    return torch.sigmoid (alpha * (logits - thresh_logit))   # alpha controls sharpness of soft step-function

def zero_one_match (input, target):   # use as simple loss function
    input = 2 * input - 1     # scale to [-1, 1]
    target = 2 * target - 1   # scale to [-1, 1]
    return (input * target).sum()     # does input match target?

inputs = torch.rand (12, requires_grad = True)
labels = torch.randint (2, (12,)).float()

print ('"straight-through estimator" thresholding -- gradients wrong and independent of threshold:')
for  threshold in torch.arange (0.2, 0.9, 0.1):
    print ('threshold: %.2f' % threshold.item())
    inputs.grad = None
    thresholded_inputs = torch.where (inputs < threshold, 0., 1.)
    trick_inputs = (thresholded_inputs - inputs).detach() + inputs   # dubious trick
    loss = zero_one_match (trick_inputs, labels)
    loss.backward()
    print ('loss:', loss)                 # loss depends on threshold
    print ('inputs.grad:', inputs.grad)   # gradients are wrong and do NOT depend on threshold

print ('soft thresholding -- gradients correct and depend on threshold:')
for  threshold in torch.arange (0.2, 0.9, 0.1):
    print ('threshold: %.2f' % threshold.item())
    inputs.grad = None
    thresholded_inputs = soft_thresh (inputs, threshold)
    loss = zero_one_match (thresholded_inputs, labels)
    loss.backward()
    print ('loss:', loss)                 # loss depends on threshold
    print ('inputs.grad:', inputs.grad)   # gradients are correct and DO depend on threshold

print ('very sharp soft thresholding -- loss close to hard threshold and gradients mostly small:')
for  threshold in torch.arange (0.2, 0.9, 0.1):
    print ('threshold: %.2f' % threshold.item())
    inputs.grad = None
    thresholded_inputs = soft_thresh (inputs, threshold, alpha = 30.0)
    loss = zero_one_match (thresholded_inputs, labels)
    loss.backward()
    print ('loss:', loss)                 # loss is nearly the same as thresholded loss
    print ('inputs.grad:', inputs.grad)   # gradients are mostly nearly zero (which is correct)

Here is its output:

1.11.0
"straight-through estimator" thresholding -- gradients wrong and independent of threshold:
threshold: 0.20
loss: tensor(2., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
threshold: 0.30
loss: tensor(0., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
threshold: 0.40
loss: tensor(-6., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
threshold: 0.50
loss: tensor(-4., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
threshold: 0.60
loss: tensor(-4., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
threshold: 0.70
loss: tensor(-4., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
threshold: 0.80
loss: tensor(-2., grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.,  2.,  2.,  2.,  2.,  2., -2., -2., -2.,  2., -2.,  2.])
soft thresholding -- gradients correct and depend on threshold:
threshold: 0.20
loss: tensor(0.7202, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 1.5833e+00,  4.5306e-03,  4.8463e-02,  1.8226e+00,  7.4010e-01,
         2.2448e+00, -3.2932e-02, -3.7912e-02, -8.2272e-01,  1.6953e-01,
        -1.2026e-01,  6.7104e+00])
threshold: 0.30
loss: tensor(-0.9067, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 3.9940,  0.0174,  0.1846,  4.3443,  0.1934,  4.8582, -0.1259, -0.1448,
        -2.5157,  0.6238, -0.4489,  5.7995])
threshold: 0.40
loss: tensor(-2.4516, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 5.2245,  0.0525,  0.5421,  5.2478,  0.0642,  5.1843, -0.3738, -0.4284,
        -4.5207,  1.6746, -1.2504,  2.9535])
threshold: 0.50
loss: tensor(-3.4529, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 4.0005,  0.1443,  1.3931,  3.7471,  0.0233,  3.3524, -0.9866, -1.1211,
        -4.8686,  3.4935, -2.8306,  1.2545])
threshold: 0.60
loss: tensor(-3.8801, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.0784,  0.3947,  3.2038,  1.8651,  0.0085,  1.5764, -2.4253, -2.6978,
        -3.2462,  5.2203, -4.9285,  0.4836])
threshold: 0.70
loss: tensor(-3.8379, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 8.0462e-01,  1.1643e+00,  5.9590e+00,  7.0719e-01,  2.8021e-03,
         5.8205e-01, -5.2944e+00, -5.5838e+00, -1.4389e+00,  4.6790e+00,
        -5.4657e+00,  1.6403e-01])
threshold: 0.80
loss: tensor(-3.2981, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 2.2210e-01,  4.0748e+00,  6.2493e+00,  1.9357e-01,  7.2826e-04,
         1.5763e-01, -7.6044e+00, -7.1622e+00, -4.2092e-01,  2.0858e+00,
        -2.9342e+00,  4.3002e-02])
very sharp soft thresholding -- loss close to hard threshold and gradients mostly small:
threshold: 0.20
loss: tensor(1.9908, grad_fn=<SumBackward0>)
inputs.grad: tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.1906e-26, 0.0000e+00,
        -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 1.5472e+00])
threshold: 0.30
loss: tensor(-0.0006, grad_fn=<SumBackward0>)
inputs.grad: tensor([7.4772e-04, 0.0000e+00, 0.0000e+00, 4.7911e-03, 2.0800e-33, 8.0937e-02,
        -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 6.9690e-03])
threshold: 0.40
loss: tensor(-5.0750, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 5.8589e+01,  0.0000e+00,  0.0000e+00,  1.9767e+01,  0.0000e+00,
         1.4525e+00, -0.0000e+00, -0.0000e+00, -9.1353e-02,  0.0000e+00,
        -0.0000e+00,  1.2207e-08])
threshold: 0.50
loss: tensor(-4.0271, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 7.7379e-04,  0.0000e+00,  0.0000e+00,  1.2311e-04,  0.0000e+00,
         7.6607e-06, -0.0000e+00, -0.0000e+00, -3.2303e+00,  3.0786e-05,
        -0.0000e+00,  6.3663e-14])
threshold: 0.60
loss: tensor(-4.0298, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 4.0354e-09,  0.0000e+00,  0.0000e+00,  6.4200e-10,  0.0000e+00,
         3.9951e-11, -0.0000e+00, -0.0000e+00, -1.7313e-05,  3.8291e+00,
        -3.5555e-02,  3.3201e-19])
threshold: 0.70
loss: tensor(-4.0264, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 7.0684e-15,  0.0000e+00,  4.1928e-02,  1.1245e-15,  0.0000e+00,
         6.9978e-17, -1.3088e-04, -1.0457e-03, -3.0326e-11,  2.9592e-02,
        -3.5225e+00,  5.8155e-25])
threshold: 0.80
loss: tensor(-2.5128, grad_fn=<SumBackward0>)
inputs.grad: tensor([ 6.7116e-22,  0.0000e+00,  2.4303e-01,  1.0678e-22,  0.0000e+00,
         6.6446e-24, -6.3990e+01, -1.0634e+01, -2.8795e-18,  2.8104e-09,
        -3.4346e-07,  5.5219e-32])

Best.

K. Frank