# Is a parameter trainable in the following case?

Hi,

Suppose there is a tensor A = [3, 4, 5, 6] with requires_grad=True, and a tensor THRES, also with requires_grad=True. Let B = A[A > THRES]. In this case, B supports backward and inherits gradients from A. I think the boolean mask [A > THRES] doesn’t carry any gradient information. Even though THRES influences the final result, it makes no direct gradient contribution. My question is: is THRES actually trainable, or does it just wander randomly around its initialization? Will it be trained the way reinforcement learning works?

Thank you!

Hi Guo!

Your reasoning is correct. The thresholding operation is not (usefully)
differentiable, so autograd does not backpropagate through it.

To verify:

```
>>> import torch
>>> torch.__version__
'1.10.2'
>>> A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
>>> THRES = torch.tensor([4.], requires_grad=True)
>>> A > THRES
tensor([False, False,  True,  True])
>>> A[A > THRES].sum().backward()
>>> A.grad
tensor([0., 0., 1., 1.])
>>> THRES.grad is None
True
```

The problem is that where `A < THRES`, the derivative of the thresholding
operation (with respect to both `A` and `THRES`) is zero, and where
`A > THRES`, the derivative is also zero (and where `A == THRES`, the
derivative is infinite or, if you prefer, undefined). So even if autograd
were to track this derivative, you wouldn’t be able to do anything useful
with it.

As it stands, `THRES` is not trainable (with a gradient-descent optimizer).
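To see this concretely, here is a minimal sketch (using the same `A` and `THRES` values as above) showing that an SGD step leaves `THRES` untouched, because the boolean comparison gives it no gradient:

```python
import torch

A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
THRES = torch.tensor([4.], requires_grad=True)

opt = torch.optim.SGD([THRES], lr=0.1)
A[A > THRES].sum().backward()

print(THRES.grad)  # None -- the comparison A > THRES is not differentiable
opt.step()         # skips THRES entirely, since its .grad is None
print(THRES)       # still tensor([4.], requires_grad=True)
```

Because `THRES.grad` stays `None`, the optimizer never moves it, no matter how many steps you take.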

But it is perfectly reasonable to use differentiable “soft” thresholding:

```
>>> A.grad = None
>>> alpha = 2.5   # sharpness of the "soft" step function
>>> (alpha * (A - THRES)).sigmoid()
tensor([0.0759, 0.5000, 0.9241, 0.9933], grad_fn=<SigmoidBackward0>)
>>> (A * (alpha * (A - THRES)).sigmoid()).sum().backward()
>>> A.grad
tensor([0.6016, 3.0000, 1.8004, 1.0930])
>>> THRES.grad
tensor([-4.0018])
```

With soft thresholding, `THRES` has a perfectly well-defined (and useful)
gradient, and you could use that gradient to train / optimize `THRES`.
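For example, here is a hedged sketch of such training (the loss, learning rate, and iteration count are made up for illustration): gradient descent on the soft-thresholded sum drives `THRES` upward, since its gradient there is negative.

```python
import torch

A = torch.tensor([3., 4., 5., 6.])            # fixed data; no grad needed here
THRES = torch.tensor([4.], requires_grad=True)
alpha = 2.5                                    # sharpness of the soft step

opt = torch.optim.SGD([THRES], lr=0.05)
for _ in range(50):
    opt.zero_grad()
    # minimizing this sum pushes THRES up, so fewer values pass the soft gate
    loss = (A * (alpha * (A - THRES)).sigmoid()).sum()
    loss.backward()
    opt.step()

print(THRES)   # THRES has moved above its initial value of 4.0
```

The point is only that `THRES` now participates in the optimization like any other parameter; what loss you actually minimize depends on your problem.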

(Note that, in this example, `A.grad` comes from two terms – one from the
first `A` in the product `A * (alpha * (A - THRES)).sigmoid()`, and one
from the second `A`, inside the `(A - THRES).sigmoid()` factor.)
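That two-term decomposition can be checked numerically; this sketch recomputes `A.grad` by hand with the product rule (same `A`, `THRES`, and `alpha` as above):

```python
import torch

A = torch.tensor([3., 4., 5., 6.], requires_grad=True)
THRES = torch.tensor([4.], requires_grad=True)
alpha = 2.5

(A * (alpha * (A - THRES)).sigmoid()).sum().backward()

# product rule: d/dA [A * s(A)] = s(A) + A * s'(A), where
# s(x) = sigmoid(alpha * (x - THRES)) and s'(x) = alpha * s(x) * (1 - s(x))
with torch.no_grad():
    s = (alpha * (A - THRES)).sigmoid()
    manual = s + A * alpha * s * (1.0 - s)

print(torch.allclose(A.grad, manual))   # True
```

The first term, `s`, comes from the `A` outside the sigmoid; the second, `A * alpha * s * (1 - s)`, from the `A` inside it.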

Best.

K. Frank


That’s the most professional answer I could ever hope for, and it really helps. Thank you very much!