Hi, sorry if the answer is trivial but I can’t seem to get how autograd differentiates torch.where operations where the mask used by torch.where is also derived from an input.
Hi
That’s a good question.
Let’s walk through three cases to see how gradients behave with torch.where().
Reminder: gradients in PyTorch flow only to differentiable inputs (like model weights or tensors with requires_grad=True). Here, we want to know how torch.where routes the gradient to those inputs.
Suppose we have
x = torch.tensor([1., 2., 3., 4.], requires_grad=True)
y = torch.tensor([-1., -2., 3., 4.], requires_grad=True)
Case-1: Constant
z = torch.where(x > 2, 1, 0)
output: [0, 0, 1, 1]
Here, both branches are constants. The condition does depend on x, but only through a non-differentiable boolean mask, so there is no differentiable relationship between the inputs and the output and no gradient can flow.
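A minimal sketch of this case (using the example x above; exact dtypes may differ slightly across PyTorch versions):

```python
import torch

x = torch.tensor([1., 2., 3., 4.], requires_grad=True)

# Both branches are constants, so the result carries no gradient history at all.
z = torch.where(x > 2, 1, 0)
print(z)                # tensor([0, 0, 1, 1])
print(z.requires_grad)  # False
print(z.grad_fn)        # None -> nothing for backward() to flow through
```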
Case-2: Function of x
z = torch.where(x > 2, x, 0)
output: [0, 0, 3, 4]
Now, if the condition is True, the corresponding element of x is selected.
- For those positions, gradients flow back to x.
- For the others (masked out), the gradient is zero.
So backpropagation works, but only through the "active" elements of x.
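Here is a quick sketch of that (same example x; I reduce with sum() only to get a scalar to call backward() on):

```python
import torch

x = torch.tensor([1., 2., 3., 4.], requires_grad=True)

z = torch.where(x > 2, x, 0.)   # tensor([0., 0., 3., 4.])
z.sum().backward()

# The gradient is 1 where an element of x was selected, 0 where it was masked out.
print(x.grad)  # tensor([0., 0., 1., 1.])
```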
Case-3: Function of both x and y
z = torch.where(x > 2, x, y)
output: [-1, -2, 3, 4]
Here:
- Where the condition is True, the gradient flows into x.
- Where the condition is False, the gradient flows into y.
Thus, both x and y can receive gradients, depending on the mask.
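Concretely, with the example x and y from above (again reducing with sum() to get a scalar loss):

```python
import torch

x = torch.tensor([1., 2., 3., 4.], requires_grad=True)
y = torch.tensor([-1., -2., 3., 4.], requires_grad=True)

z = torch.where(x > 2, x, y)    # tensor([-1., -2., 3., 4.])
z.sum().backward()

# The mask routes the gradient: x and y receive complementary patterns.
print(x.grad)  # tensor([0., 0., 1., 1.])
print(y.grad)  # tensor([1., 1., 0., 0.])
```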
Mathematically, the operator is deterministic, so we can compute the gradient. To my knowledge, only stochastic operators (like sampling) break the flow of gradients!
Hope it helps!
Sorry if I misunderstand your question, but are you asking how torch can differentiate the result of torch.where with respect to its condition parameter? It doesn't: condition is a BoolTensor, which will never require grad, so it is treated as a constant.
However, as explained by @Arunprakash-A, torch can differentiate the result of torch.where with respect to input and/or other (depending on their requires_grad field), and this differentiation depends on the value of condition.
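A small sketch to illustrate this (the factors 2 and 3 are just made up for the example): even though the mask is derived from x, the comparison itself contributes nothing to the gradient; only the selected branches do:

```python
import torch

x = torch.tensor([1., 2., 3., 4.], requires_grad=True)

cond = x > 2
print(cond.dtype)          # torch.bool
print(cond.requires_grad)  # False -> the mask is treated as a constant

# The mask is derived from x, but backward only sees the selected branch values,
# not the thresholding that produced the mask.
z = torch.where(cond, x * 2, x * 3)
z.sum().backward()
print(x.grad)  # tensor([3., 3., 2., 2.])
```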