Hi all, when `torch.where` is used as a hard threshold, gradients cannot flow through it, so `torch.autograd.grad(…)` returns `None`. Is there a way to keep the same graph and pass the gradient through to the next differentiable operation?
For example, in the very simple code below:

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)  # ** position2 **: grad couldn't pass back here
c = torch.where(c > 2,
                torch.tensor(1.0, requires_grad=True),
                torch.tensor(0.0, requires_grad=True))  # ** position1 **: the furthest place grad could arrive
loss = c.sum()
print(torch.autograd.grad(loss, a, allow_unused=True))
```
Here `c` uses `torch.where` as a threshold to get binary values, so of course the gradient cannot reach position2. I just wonder whether there is some way to take `c`'s gradient at position1 and pass it directly to position2, then keep computing gradients until they reach `a`.
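One common workaround for "copying the gradient past a non-differentiable step" is the straight-through estimator. Below is a minimal sketch (my own, not an official PyTorch API): the `detach()` trick makes the forward value equal to the hard threshold while the backward pass treats the operation as the identity, so the gradient flows from position1 to position2 unchanged.

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)

# Straight-through estimator:
# forward value == hard threshold, backward gradient == identity through c
hard = (c > 2).float()
c_st = c + (hard - c).detach()

loss = c_st.sum()
g = torch.autograd.grad(loss, a)[0]
print(g)  # gradient now reaches a
```

The key line is `c + (hard - c).detach()`: numerically it equals `hard`, but autograd only sees the plain `c` term, so `d(loss)/d(c)` is passed straight through the threshold.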
The above is only a hand-written example. When using a built-in optimizer, after computing the loss and calling `loss.backward()`, is there any way to let the backward pass bypass `torch.where()` as in the example above?
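For the `loss.backward()` / optimizer case, the same idea can be packaged as a custom `torch.autograd.Function` whose backward simply passes the incoming gradient through. A sketch of what I mean (the class name `StraightThroughThreshold` is my own, hypothetical):

```python
import torch

class StraightThroughThreshold(torch.autograd.Function):
    """Hard threshold in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, x, threshold):
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through; no gradient for the threshold
        return grad_output, None

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
opt = torch.optim.SGD([a], lr=0.1)

loss = StraightThroughThreshold.apply(torch.mm(a, b), 2.0).sum()
loss.backward()   # a.grad is populated despite the hard threshold
opt.step()
print(a.grad)
```

This fits an ordinary training loop: `backward()` fills `a.grad`, and the optimizer step proceeds as usual.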