How do I pass gradients through torch.where?

Hi guys, when I use torch.where as a threshold, the gradient cannot pass back through it with torch.autograd.grad(…). Is there a way to keep the same graph and pass the gradient through to the next differentiable operation?
For example, in the very simple code below:

```python
import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)

# ** position2 **: the gradient cannot flow back past this point
c = torch.where(c > 2, torch.tensor(1.0, requires_grad=True), torch.tensor(0.0, requires_grad=True))
# ** position1 **: the furthest point the gradient can reach
loss = c.sum()

# a is not connected to loss, so this prints (None,)
print(torch.autograd.grad(loss, a, allow_unused=True))
```

Here c uses torch.where as a threshold to get binary values, so of course the gradient cannot reach position2. I wonder whether there is a way to take c's gradient at position1, copy it directly to position2, and keep backpropagating from there until it reaches a.
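A quick check of why the gradient stops (the tensor x here is illustrative, not from the post): a comparison such as c > 2 produces a boolean tensor that carries no gradient, and torch.where only sends gradient back into its two value arguments, at the positions each one was selected from. In the example above both value arguments are constants, so nothing connects loss to a.

```python
import torch

# Illustrative input, not from the original post.
x = torch.tensor([1.0, -1.0], requires_grad=True)

# The boolean mask produced by a comparison never carries a gradient.
mask = x > 0
print(mask.requires_grad)  # False

# torch.where routes the incoming gradient only into the value arguments,
# and only where each argument was selected.
out = torch.where(mask, x * 2, torch.zeros_like(x))
out.sum().backward()
print(x.grad)  # tensor([2., 0.])
```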

The above is only a hand-written example. When using a built-in optimizer, after computing the loss and calling loss.backward(), is there any way to make the backward pass bypass torch.where() as in the example above?
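One alternative that avoids a custom class is the straight-through trick written with .detach() (a different technique from the custom Function suggested in the reply). The forward value is the thresholded tensor, while the backward pass treats the threshold as the identity, so it works unchanged with loss.backward() and a built-in optimizer. A minimal sketch on the same a and b, assuming plain SGD with an arbitrary lr=0.1:

```python
import torch

# Same tensors as the original example; SGD and lr=0.1 are illustrative choices.
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
optimizer = torch.optim.SGD([a], lr=0.1)

optimizer.zero_grad()
c = torch.mm(a, b)
hard = torch.where(c > 2, torch.tensor(1.0), torch.tensor(0.0))

# Straight-through trick: the forward value equals `hard` (up to rounding),
# but (hard - c) is detached, so backward sees only the identity through c.
c_st = c + (hard - c).detach()
loss = c_st.sum()
loss.backward()

print(a.grad)  # each row equals b.T, i.e. [0.1, 0.9]
optimizer.step()  # a is now updated even though the forward used a threshold
```

Because c + (hard - c).detach() equals hard only up to floating-point rounding, compare its values with a tolerance if that matters.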

Hi,

You can write a custom Function whose forward is your torch.where and whose backward is the identity.
You can find how to write such a Function here.
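A minimal sketch of such a Function, assuming the threshold is a plain Python number; the class name ThresholdSTE is hypothetical. The forward pass is the torch.where threshold and the backward pass returns the incoming gradient unchanged:

```python
import torch


class ThresholdSTE(torch.autograd.Function):
    """Hypothetical name: forward is a torch.where threshold, backward is the identity."""

    @staticmethod
    def forward(ctx, x, threshold):
        # Binary output: 1.0 where x > threshold, else 0.0 (same dtype/device as x).
        return torch.where(x > threshold, torch.ones_like(x), torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # Identity: pass the incoming gradient straight through to x.
        # threshold is a plain Python number, so it gets no gradient (None).
        return grad_output, None


a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)
c = ThresholdSTE.apply(c, 2.0)  # replaces the torch.where line
loss = c.sum()

(grad_a,) = torch.autograd.grad(loss, a)
print(grad_a)  # each row equals b.T, i.e. [0.1, 0.9]
```

Nothing else changes when training with loss.backward() and an optimizer: call ThresholdSTE.apply wherever torch.where was used.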


Problem solved! Thanks so much