Hi all, when using **torch.where** as a threshold, the gradient cannot pass through it via torch.autograd.grad(…). Is there a way to keep the same graph and pass the gradient through to the next differentiable operation?

For example, in the very simple code below:

```
import torch
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1],[0.9]])
c = torch.mm(a, b)
# ** position2 **: the grad cannot pass back to here
c = torch.where(c > 2, torch.tensor(1.0, requires_grad=True), torch.tensor(0.0, requires_grad=True))
# ** position1 **: the furthest point the grad can reach
loss = c.sum()
print(torch.autograd.grad(loss, a, allow_unused=True))
```

Here **c** uses torch.where as a threshold to produce binary values, so of course the grad cannot arrive at **position2**. I wonder whether there is a method to take c's grad at **position1** and paste it directly to **position2**, then keep computing grads until they arrive at **a**.

The above is only a hand-written example. When using a built-in optimizer, after computing the loss and calling loss.backward(), is there any method to let the backward pass **bypass** torch.where() as in the example above?
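For the loss.backward() case, the same straight-through idea can be wrapped in a custom torch.autograd.Function, which then works transparently with any built-in optimizer. The class name `ThresholdSTE` and the fixed threshold of 2 are my own choices for this sketch:

```
import torch

class ThresholdSTE(torch.autograd.Function):
    """Binary threshold in the forward pass, identity gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        # Hypothetical fixed threshold of 2, matching the example above.
        return (x > 2).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming gradient straight through, as if
        # the forward pass had been the identity function.
        return grad_output

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = ThresholdSTE.apply(torch.mm(a, b))
loss = c.sum()
loss.backward()  # the backward pass now "bypasses" the threshold
print(a.grad)
```

Since `loss.backward()` populates a.grad as usual, an optimizer constructed over `[a]` (e.g. `torch.optim.SGD([a], lr=0.1)`) can step normally.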