How do I pass the grad through torch.where?

Hi guys, torch.where acts as a hard threshold, so torch.autograd.grad(…) cannot pass the grad through it. Is there a way to keep the same graph and pass the grad through to the next differentiable operation?
For example, in the very simple code below:

import torch

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1],[0.9]])
c = torch.mm(a, b)

# ** position2 **: the grad cannot flow back to here
c = torch.where(c>2, torch.tensor(1.0, requires_grad=True), torch.tensor(0.0, requires_grad=True))
# ** position1 **: the furthest point the grad can reach
loss = c.sum()

print(torch.autograd.grad(loss, a, allow_unused=True))

Here c uses torch.where as a threshold to produce binary values, so of course the grad cannot reach position2. I just wonder whether there is a way to take c’s grad at position1 and paste it directly to position2, then keep computing grads until they reach a.

The above is only a hand-written example. When using a built-in optimizer, after computing the loss and calling loss.backward(), is there any way to let the backward pass skip over torch.where() as in the example above?


You can write a custom Function whose backward is the identity while the forward is your torch.where.
You can find here how to write such a Function.
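
As a rough illustration of that suggestion, here is a minimal sketch applied to the example from the question. The class name ThresholdSTE and the threshold argument are hypothetical, and the sketch assumes c is the matrix product of a and b. The forward does the torch.where binarization, and the backward just hands the incoming grad back unchanged.

```python
import torch

class ThresholdSTE(torch.autograd.Function):
    """Hypothetical helper: hard threshold in forward, identity in backward."""

    @staticmethod
    def forward(ctx, x, threshold):
        # Same binarization as torch.where(c > 2, 1.0, 0.0) in the question.
        return torch.where(x > threshold, torch.ones_like(x), torch.zeros_like(x))

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the incoming grad through unchanged.
        # The second value is for `threshold`, which needs no grad.
        return grad_output, None

a = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
b = torch.tensor([[0.1], [0.9]])
c = torch.mm(a, b)               # assumption: matrix product of a and b
c = ThresholdSTE.apply(c, 2.0)   # binary values, but the graph stays connected
loss = c.sum()

(grad_a,) = torch.autograd.grad(loss, a)
print(grad_a)
```

Since the thresholding step is treated as an identity in the backward pass, the grad that reaches a is the same as if loss had been computed from torch.mm(a, b) directly. The same Function also works with loss.backward() and a built-in optimizer, because both rely on the same autograd graph.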

Problem solved! Thanks so much

I’m struggling with the implementation.
This is my current code (I am using PyTorch 1.12.1):

import torch

class CustomWhere(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Save the input so that backward can read it from ctx.saved_tensors
        ctx.save_for_backward(input)
        result = torch.where(input > 0.1, 1.0, 0.0)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        return torch.ones(input.shape[0])

def custom_where(x):
    result = CustomWhere.apply(x)
    return result

which I think implements what you mentioned. I’ve also tried

return grad_output * torch.ones(input.shape[0])


return torch.eye(input.shape[0])


return grad_output * torch.eye(input.shape[0])

I test all the approaches using gradcheck

from torch.autograd import gradcheck
input = torch.randn(20, dtype=torch.double,requires_grad=True)
test = gradcheck(custom_where, input, eps=1e-6, atol=1e-4)

but they all unfortunately fail with

GradcheckError: Jacobian mismatch for output 0 with respect to input 0

What am I doing wrong?


Your backward here is not an identity: it returns a constant 1 :)
Simply do return grad_output in there.
This is because the backward of a custom Function computes a vector-Jacobian product step, not the Jacobian itself.
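
To make the vector-Jacobian point concrete, here is a small sketch of the Function with the suggested fix. The upstream vector v is an arbitrary choice made up for the demonstration. If the Jacobian is pretended to be the identity I, then for an incoming grad v the backward has to return v @ I, which is just v.

```python
import torch
from torch.autograd import gradcheck

class CustomWhere(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return torch.where(input > 0.1, torch.ones_like(input), torch.zeros_like(input))

    @staticmethod
    def backward(ctx, grad_output):
        # Vector-Jacobian product with a pretend Jacobian J = I: v @ I == v.
        return grad_output

x = torch.randn(20, dtype=torch.double, requires_grad=True)
v = torch.arange(20, dtype=torch.double)  # arbitrary upstream grad (illustrative)

y = CustomWhere.apply(x)
(grad_x,) = torch.autograd.grad(y, x, grad_outputs=v)
print(torch.equal(grad_x, v))  # the upstream grad comes back unchanged

# gradcheck compares against the numerical Jacobian of the forward, which is
# zero almost everywhere for a step function, so it still reports a mismatch.
ok = gradcheck(CustomWhere.apply, (x,), eps=1e-6, atol=1e-4, raise_exception=False)
print(ok)
```

Note that gradcheck is not expected to pass for this kind of Function: the identity backward is a deliberate mismatch with the true derivative of the threshold, so a failing gradcheck here does not by itself mean the implementation is broken.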