Hi Guo!
Your reasoning is correct. The thresholding operation is not (usefully)
differentiable, so autograd does not backpropagate through it.
To verify:
>>> import torch
>>> torch.__version__
'1.10.2'
>>> A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
>>> THRES = torch.tensor([4.], requires_grad=True)
>>> A > THRES
tensor([False, False, True, True])
>>> A[A > THRES].sum().backward()
>>> A.grad
tensor([0., 0., 1., 1.])
>>> THRES.grad == None
True
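You can see where the graph gets cut by inspecting the result of the comparison directly: boolean tensors are never tracked by autograd. A quick check, using the same tensors as above:

```python
import torch

# The comparison produces a boolean tensor that autograd does not
# track, so the graph to THRES is cut right at this point.
A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
THRES = torch.tensor([4.], requires_grad=True)

mask = A > THRES
print(mask.requires_grad)   # False
print(mask.grad_fn)         # None
```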
The problem is that when A < THRES, the derivative of the comparison (with
respect to both A and THRES) is zero, and when A > THRES, the derivative is
also zero (and when A == THRES, the derivative is infinite or, if you
prefer, undefined). So even if autograd were to track the derivative,
you wouldn’t be able to do anything useful with it.
As it stands, THRES is not trainable (with a gradient-descent optimizer).
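To make that concrete, here is a minimal sketch (same tensors as above) showing that an SGD step is a no-op for THRES under hard thresholding, because its .grad is never populated:

```python
import torch

# With hard thresholding, THRES never receives a gradient, so a
# gradient-descent optimizer has nothing with which to update it.
A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
THRES = torch.tensor([4.], requires_grad=True)

opt = torch.optim.SGD([THRES], lr=0.1)
A[A > THRES].sum().backward()
print(THRES.grad)   # None -- the comparison cuts the graph
opt.step()          # no-op for THRES, since its .grad is None
print(THRES)        # still tensor([4.], requires_grad=True)
```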
But it is perfectly reasonable to use differentiable “soft” thresholding:
>>> A.grad = None
>>> alpha = 2.5 # sharpness of "soft" step function
>>> (alpha * (A - THRES)).sigmoid()
tensor([0.0759, 0.5000, 0.9241, 0.9933], grad_fn=<SigmoidBackward0>)
>>> (A * (alpha * (A - THRES)).sigmoid()).sum().backward()
>>> A.grad
tensor([0.6016, 3.0000, 1.8004, 1.0930])
>>> THRES.grad
tensor([-4.0018])
With soft thresholding, THRES has a perfectly well-defined (and useful)
gradient, and you could use that gradient to train / optimize THRES.
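For example, here is a minimal training sketch. The target value and learning rate are made up purely for illustration, but it shows SGD actually moving THRES by following the gradient that flows through the soft threshold:

```python
import torch

# Minimal training sketch: the target and learning rate below are
# hypothetical.  SGD moves THRES by descending the gradient that
# flows through the "soft" step function.
A = torch.tensor([3., 4., 5., 6.])
THRES = torch.tensor([4.], requires_grad=True)
alpha = 2.5
target = torch.tensor(11.)   # made-up target for the soft-thresholded sum

opt = torch.optim.SGD([THRES], lr=0.01)
for _ in range(200):
    opt.zero_grad()
    soft_sum = (A * (alpha * (A - THRES)).sigmoid()).sum()
    loss = (soft_sum - target) ** 2
    loss.backward()
    opt.step()

print(THRES)   # THRES has been pushed away from its initial 4.0
print(loss)    # loss has been driven (nearly) to zero
```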
(Note, in this example, A.grad comes from two terms – one from the
first A in the product A * (alpha * (A - THRES)).sigmoid(), and
one from the second A in the product – in the (A - THRES).sigmoid()
term.)
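If you want to verify that decomposition by hand, you can recompute the two product-rule terms explicitly and compare their sum with autograd's A.grad (a quick check, same tensors as above):

```python
import torch

# Recompute the two product-rule terms of
#   d/dA [A * sigmoid(alpha * (A - THRES))]
#     = sigmoid(...) + A * alpha * sigmoid(...) * (1 - sigmoid(...))
# by hand and compare with autograd's A.grad.
A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
THRES = torch.tensor([4.], requires_grad=True)
alpha = 2.5

(A * (alpha * (A - THRES)).sigmoid()).sum().backward()

with torch.no_grad():
    s = (alpha * (A - THRES)).sigmoid()
    term1 = s                         # from the first A in the product
    term2 = A * alpha * s * (1 - s)   # from the A inside the sigmoid
    print(torch.allclose(A.grad, term1 + term2))   # True
```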
Best.
K. Frank