import torch  # assuming a is your existing tensor

eps = 1e-7
z = torch.zeros((), device=a.device, dtype=a.dtype)
# keep entries of a that are within eps of 1.5; replace the rest with 0
a_new = torch.where((a - 1.5).abs() < eps, a, z)
That is probably what you want.
Note that a == 1.5 might not be a good condition to test for due to the accuracy limits of floating-point computation.
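To see why an exact equality test can be unreliable, here is a small pure-Python illustration of the same rounding issue (the specific numbers are just an example, not from this thread):

```python
# Floating-point arithmetic is inexact: a value computed to be "0.3"
# can differ from the literal 0.3 by a tiny rounding error.
print(0.1 + 0.2 == 0.3)                # False: the sum is 0.30000000000000004
print(abs((0.1 + 0.2) - 0.3) < 1e-7)   # True: a tolerance-based test passes
```

The same reasoning applies to comparing tensor entries against 1.5, which is why the snippet above tests `(a - 1.5).abs() < eps` instead of `a == 1.5`.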
You can assign the result back to a if you want.
Oh sorry, I had to declare z as an nn.Parameter with requires_grad=True.
That solved the problem.
But threshold.grad returns 0, so I am unclear whether the gradient is actually being computed for threshold, even though its requires_grad attribute is True. If not, am I interpreting it wrong?
Is there a way we can have a learnable threshold?
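One possible direction (a sketch under my own assumptions, not something stated earlier in this thread): a hard cutoff like `torch.where(x > threshold, x, zero)` gives the threshold a zero gradient, because the comparison is piecewise constant in the threshold. A common workaround is a *soft* threshold built from a sigmoid, which is differentiable in the threshold parameter. The variable names and the steepness constant `k` here are illustrative:

```python
import torch

# Learnable threshold via a soft (sigmoid) gate.
threshold = torch.nn.Parameter(torch.tensor(0.5))
x = torch.tensor([0.2, 0.4, 0.6, 0.8])

k = 50.0  # steepness; larger k approximates a hard cutoff more closely
gate = torch.sigmoid(k * (x - threshold))  # ~0 below threshold, ~1 above
y = (gate * x).sum()
y.backward()

print(threshold.grad)  # nonzero, so threshold receives a useful gradient
```

With a hard `torch.where` the gradient with respect to `threshold` would be exactly zero everywhere, which matches the `threshold.grad == 0` behavior described above; the soft gate trades exactness of the cutoff for a nonzero gradient.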