Hi guys, when I use torch.where as a threshold, the gradient cannot flow back through it, so torch.autograd.grad(…) cannot reach the tensors before it. Is there a way to keep the same graph and pass the gradient straight through to the earlier differentiable operation?
For example, take this very simple code:
```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)
# ** position2 **: the gradient cannot flow back to here
c = torch.where(c > 2, torch.tensor(1.0, requires_grad=True), torch.tensor(0.0, requires_grad=True))
# ** position1 **: the furthest point the gradient reaches
loss = c.sum()
print(torch.autograd.grad(loss, a, allow_unused=True))  # (None,): a is not reachable from loss
```
Here c uses torch.where as a threshold to get binary values, so of course the gradient cannot reach position2. I just wonder whether there is some way to take c's gradient at position1, paste it directly to position2, and then keep computing gradients back until they reach a.
The above is only a hand-written example. When using a built-in optimizer, after computing the loss and calling loss.backward(), is there any way to let the backward pass skip over torch.where() as in the example above?
You can: write a custom Function whose backward is the identity while its forward is your torch.where.
You can find how to write such a Function here.
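A minimal sketch of that idea, applied to the example from the question. The class name `ThresholdSTE` is made up for illustration, and the threshold of 2 is taken from the original code. The forward does the same binarization as the torch.where call; the backward takes the gradient arriving at position1 (dloss/dc) and hands it unchanged to position2. This is usually called a straight-through estimator. Because it is an ordinary autograd Function, it works the same with torch.autograd.grad and with loss.backward() plus a built-in optimizer.

```python
import torch


class ThresholdSTE(torch.autograd.Function):
    """Hard threshold in the forward pass, identity (straight-through) in the backward pass."""

    @staticmethod
    def forward(ctx, x, threshold):
        # Same binarization as the torch.where call in the question.
        return torch.where(x > threshold, torch.ones_like(x), torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # Copy the incoming gradient through unchanged.
        # One return value per forward input; the float threshold gets None.
        return grad_output, None


a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])

# Same graph as in the question, with the custom Function instead of torch.where.
c = ThresholdSTE.apply(torch.mm(a, b), 2.0)
loss = c.sum()
(grad_a,) = torch.autograd.grad(loss, a)
print(grad_a)  # gradient now reaches a

# The same Function with loss.backward() and a built-in optimizer.
opt = torch.optim.SGD([a], lr=0.1)
opt.zero_grad()
loss2 = ThresholdSTE.apply(torch.mm(a, b), 2.0).sum()
loss2.backward()  # fills a.grad through the straight-through backward
opt.step()
```

Since loss is a plain sum, dloss/dc is all ones here, so a's gradient is just that ones vector multiplied back through torch.mm (each row of grad_a equals b transposed).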
Your backward here is not the identity; it returns a constant 1.
Simply return grad_output there.
This is because a custom Function's backward computes one vector-Jacobian product step, not the Jacobian itself.
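To make the vector-Jacobian point concrete: backward receives grad_output = dL/dy and must return dL/dx = grad_output · J. With the straight-through choice J = I, that is grad_output itself. Returning a constant 1 only matches when dL/dy happens to be all ones, as with loss = c.sum() in the question, which is why the bug can go unnoticed. The sketch below uses two made-up Function classes and hypothetical downstream weights `w` so that dL/dy is not all ones.

```python
import torch


class ConstOneBackward(torch.autograd.Function):
    # Buggy variant: backward ignores grad_output and returns a constant 1.
    @staticmethod
    def forward(ctx, x):
        return (x > 2).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        return torch.ones_like(grad_output)


class IdentityBackward(torch.autograd.Function):
    # Straight-through variant: backward returns grad_output unchanged.
    @staticmethod
    def forward(ctx, x):
        return (x > 2).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


x = torch.tensor([1.9, 3.9], requires_grad=True)
w = torch.tensor([5.0, -3.0])  # hypothetical downstream weights, so dL/dy = w

(g_const,) = torch.autograd.grad((w * ConstOneBackward.apply(x)).sum(), x)
(g_ident,) = torch.autograd.grad((w * IdentityBackward.apply(x)).sum(), x)
print(g_const)  # all ones: the upstream gradient w is lost
print(g_ident)  # equals w: the upstream gradient is passed through
```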