How does torch.where work in autograd

Hi, sorry if the answer is trivial, but I can’t seem to understand how autograd differentiates torch.where operations when the mask used by torch.where is itself derived from an input.

Hi

That’s a good question.

Let’s walk through three cases to see how gradients behave with torch.where().
Reminder: gradients in PyTorch flow only to differentiable inputs (like model weights or tensors with requires_grad=True). Here, we want to know how torch.where routes the gradient.

Suppose we have

import torch

x = torch.tensor([1., 2., 3., 4.], requires_grad=True)
y = torch.tensor([-1., -2., 3., 4.], requires_grad=True)

Case-1: Constant

z = torch.where(x > 2, 1, 0)

output: [0, 0, 1, 1]

Here, the result is constant with respect to both x and y. The output values do not depend differentiably on the inputs (the mask is a boolean constant from autograd’s point of view), so no gradient can flow.
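A quick sanity check, building on the tensors defined above (this assumes a recent PyTorch, where torch.where accepts Python scalars for both branches):

z = torch.where(x > 2, 1, 0)
print(z)                # tensor([0, 0, 1, 1])
print(z.requires_grad)  # False: z is not connected to x in the autograd graph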

Case-2: Function of x

z = torch.where(x > 2, x, 0.)
output: [0., 0., 3., 4.]

Now, if the condition is True, the corresponding element of x is selected.

  • For those positions, gradients flow back to x
  • For the others (masked out), the gradient is zero

So backpropagation works, but only through the “active” elements of x.
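A minimal sketch verifying this with the tensors defined above:

x.grad = None  # clear any previously accumulated gradient
z = torch.where(x > 2, x, 0.)
z.sum().backward()
print(x.grad)  # tensor([0., 0., 1., 1.]) -- gradient only where the mask is True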

Case-3: function of both x and y

z = torch.where(x > 2, x, y)
output: [-1., -2., 3., 4.]

Here:

  • Where the condition is True, the gradient flows into x
  • Where the condition is False, the gradient flows into y

Thus, both x and y can receive gradients, depending on the mask.
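The same check for this case:

x.grad, y.grad = None, None  # clear accumulated gradients
z = torch.where(x > 2, x, y)
z.sum().backward()
print(x.grad)  # tensor([0., 0., 1., 1.])
print(y.grad)  # tensor([1., 1., 0., 0.]) -- the complement of x's gradient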

Mathematically, the operator is deterministic, so we can compute the gradient. To my knowledge, only stochastic operators (like sampling) break the flow of gradients!

Hope it helps!

Sorry if I misunderstand your question, but are you asking how torch can differentiate the result of torch.where with respect to its condition parameter? It doesn’t: condition is a BoolTensor, which will never require grad, so it’s treated as a constant.

However, as explained by @Arunprakash-A, torch can differentiate the result of torch.where with respect to either input and/or other (depending on their requires_grad fields). And this differentiation depends on the value of condition.
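A minimal check of this point, reusing x from above:

mask = x > 2
print(mask.dtype)          # torch.bool
print(mask.requires_grad)  # False: comparisons produce boolean tensors detached from the graph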
