Hi guys, when torch.where is used as a threshold, the gradient cannot pass through it via torch.autograd.grad(…). Is there a way to keep the same graph and pass the gradient through to the next differentiable operation?
For example, in the very simple code below:
import torch
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1],[0.9]])
c = torch.mm(a, b)
# ** position2 **: the grad cannot pass back to here
c = torch.where(c>2, torch.tensor(1.0, requires_grad=True), torch.tensor(0.0, requires_grad=True))
# ** position1 **: the furthest point the grad can reach
loss = c.sum()
print(torch.autograd.grad(loss, a, allow_unused=True))
Here c uses torch.where as a threshold to get binary values, so of course the grad cannot reach position2. I just wonder whether there is some way to take c's grad at position1, paste it directly to position2, and then keep propagating the grad back until it reaches a.
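(For reference, what I'm imagining is something like the straight-through-estimator trick below. This is only a minimal sketch, assuming the goal is to use the soft value's gradient in place of the hard value's:)

import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)

hard = (c > 2).float()          # binary values, detached from the graph
c_ste = hard + c - c.detach()   # forward value equals hard, backward gradient flows through c
loss = c_ste.sum()
print(torch.autograd.grad(loss, a))   # the grad now reaches a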
The above is only a hand-written example. When using a built-in optimizer, after computing the loss and calling loss.backward(), is there any way to let the backward pass skip over torch.where() like in the example above?
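In case it helps clarify what I mean, here is a sketch with a custom torch.autograd.Function (the name ThresholdSTE is just made up) whose backward passes the incoming gradient through unchanged, so it would also work with loss.backward() and a built-in optimizer:

import torch

class ThresholdSTE(torch.autograd.Function):
    # Hard threshold in forward, identity gradient in backward (straight-through)

    @staticmethod
    def forward(ctx, x):
        # same result as the torch.where threshold above
        return torch.where(x > 2, torch.ones_like(x), torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # pass the incoming gradient straight through, as if the op were the identity
        return grad_output

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = ThresholdSTE.apply(torch.mm(a, b))
loss = c.sum()
loss.backward()   # the grad now reaches a, so an optimizer could update it
print(a.grad)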