Gradients of torch.where

Hello,

I am trying to calculate gradients of a function that uses torch.where, but it produces unexpected gradients. I use it to choose between a real case, a complex case and a limit case, where some of the cases have a NaN gradient for specific inputs. For simplicity, consider the following example:

import torch

def f1(x):
    return 0 / x  # NaN at x == 0, with a NaN gradient

def f2(x):
    return x

def g(x):
    r1 = f1(x)
    r2 = f2(x)
    return torch.where(x == 0.0, r2, r1)  # pick f2 at x == 0, f1 elsewhere

x = torch.zeros(1, requires_grad=True)  # torch.autograd.Variable is deprecated
print(torch.autograd.grad(g(x), x))
# >> (tensor([nan]),)
print(torch.autograd.grad(f2(x), x))
# >> (tensor([1.]),)

I would expect the gradient of g(x) to be the same as that of f2(x), but they are clearly different. This problem only happens when one of the functions has a NaN gradient, i.e. if I change f1 to something like x**2 I get the correct result. I do not know how gradients are propagated through torch.where, but if it is something like

torch.where(condition, tensor1, tensor2) = condition * tensor1 + (1-condition) * tensor2

then it makes sense that the resulting gradient is NaN, even if condition = 1 and only tensor2 has a NaN gradient.
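
A quick way to test this hypothesis is to compute the arithmetic blend next to torch.where. In this sketch (the names cond, blend and picked are illustrative, not from the thread), the blend is already NaN in the forward pass because 0 * NaN = NaN, while torch.where returns the selected value 0. So torch.where does not literally evaluate that formula, yet both routes still end up with a NaN gradient for x.

```python
import torch

x = torch.zeros(1, requires_grad=True)
r1 = 0 / x        # NaN at x == 0
r2 = x.clone()
cond = x == 0

# The hypothesised arithmetic blend: condition * tensor1 + (1 - condition) * tensor2
blend = cond.float() * r2 + (1 - cond.float()) * r1
# torch.where selects elements instead of multiplying by the condition
picked = torch.where(cond, r2, r1)

print(blend)   # NaN, since 1 * 0 + 0 * NaN = NaN
print(picked)  # 0, the value taken from r2

# Both routes still give a NaN gradient for x
(g_blend,) = torch.autograd.grad(blend, x, retain_graph=True)
(g_picked,) = torch.autograd.grad(picked, x)
print(g_blend, g_picked)
```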

Is this the intended behavior of torch.where, or is some NaN checking missing? If it is intended, is there a way to get around it?

Thank you

While it is a bit confusing at first sight, there is little that can be done here; in particular, torch.where is working as expected.
When you instrument it like this:

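import torch

def f1(x):
    return 0 / x

def f2(x):
    return x.clone()  # clone so that r2 is a separate non-leaf tensor

x = torch.zeros(1, requires_grad=True)
r1 = f1(x)
r1.retain_grad()  # keep .grad of intermediate tensors for inspection
r2 = f2(x)
r2.retain_grad()
c = torch.where(x == 0, r2, r1)
c.backward()
print(r1.grad, r2.grad, x.grad)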

Here you see that r1.grad is 0, r2.grad is 1 and x.grad is NaN.

The gradient x.grad is NaN because even when you feed a grad of 0 to the backward of f1, you’ll get NaN as the gradient of f1’s input.
The gradient of x is the sum of the input gradients of f1 and f2, i.e. NaN + 1 = NaN.
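
The claim that a zero upstream gradient does not help can be checked directly by running each branch's backward on its own, using torch.autograd.grad with an explicit grad_outputs. This is a minimal sketch reusing f1 and f2 from above: feeding zeros into f1's backward still yields NaN, and adding f2's contribution of 1 keeps the sum NaN.

```python
import torch

x = torch.zeros(1, requires_grad=True)
r1 = 0 / x        # f1(x)
r2 = x.clone()    # f2(x)

# At this x, where(x == 0, r2, r1) sends a gradient of 0 to r1 and 1 to r2
(g1,) = torch.autograd.grad(r1, x, grad_outputs=torch.zeros_like(r1))
(g2,) = torch.autograd.grad(r2, x, grad_outputs=torch.ones_like(r2))

print(g1)       # NaN, even though the incoming gradient was 0
print(g2)       # 1
print(g1 + g2)  # NaN, which is what ends up in x.grad
```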

Best regards

Thomas


It seems related to something I noted a while ago in the thread “Bug or feature? NaNs influence other variables in backprop”.
In that case there was a multiplication by 0 inside PyTorch's code that computed 0 * NaN and obtained NaN. My workaround has been to replace the NaNs with zeros.
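
One way the “replace NaNs with zeros” workaround could be applied to the original example is a gradient hook on a copy of x that only f1 uses, so that the NaN is replaced before it is summed with f2's contribution. This is a sketch under that assumption, not necessarily what was done in the linked thread; the alias x1 is illustrative. A hook on x itself would not work, because by then the NaN has already been added to f2's gradient.

```python
import torch

x = torch.zeros(1, requires_grad=True)

# Give f1 its own copy of x; the hook only sees f1's share of the gradient
x1 = x.clone()
x1.register_hook(torch.nan_to_num)  # NaN -> 0 (and +/-inf -> large finite values)

r1 = 0 / x1        # f1
r2 = x.clone()     # f2
out = torch.where(x == 0, r2, r1)
out.backward()
print(x.grad)  # 1 instead of NaN
```

The drawback is that the hook also silences genuine NaN gradients from f1 at inputs where its branch is actually selected.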

Anyway, even if it’s expected behavior, I don’t feel the PyTorch implementation is the most reasonable one.

No, this is wrong. (Consider torch.where(..., r1, r1). With your argument, the gradient should double, but clearly this is just the identity function.)

Your example does not change the fact that the backward of “use the same tensor several times” is the addition of the gradients of the various uses.

In your example, conceptually, the gradient of the loss w.r.t. the output of where is split between the second argument, at the positions where the condition in the first argument is True/nonzero (this goes to the first use of r1), and the third argument (this goes to the second use of r1). Because you put the same tensor into both arguments, the gradients of these two uses are added to get the identity, but each of them is a masked copy of the gradient w.r.t. the output of where.

Just because you could simplify the expression does not mean that the autograd mechanics change for the operators you give. Of course, if you run this through a compiler that optimizes the case, the gradient computation might also change.
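
The “masked copies” can be made visible by giving where two separate aliases of the same values, so that the gradient of each use can be inspected. In this sketch (a and b are illustrative aliases of x, standing in for the two uses of r1), each use receives the upstream gradient only at its own positions, and the two masked copies add back up to the full upstream gradient for x, i.e. the identity.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
mask = torch.tensor([True, False, True])
grad_out = torch.tensor([10.0, 20.0, 30.0])

# Two uses of the same values, as in torch.where(mask, r1, r1)
a = x * 1.0
b = x * 1.0
out = torch.where(mask, a, b)

ga, gb = torch.autograd.grad(out, (a, b), grad_out, retain_graph=True)
print(ga)  # values [10, 0, 30]: copy of grad_out where mask is True
print(gb)  # values [0, 20, 0]: copy of grad_out where mask is False

(gx,) = torch.autograd.grad(out, x, grad_out)
print(gx)  # values [10, 20, 30]: the sum of both uses, i.e. the identity
```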

So you are saying that the sum is w.r.t. the masked tensor?

What I expect for the gradient, given:

c = where(mask, a, b)

For the gradients of the loss w.r.t. a and b, denoted da and db, in terms of the gradient of the loss w.r.t. c, denoted dc:

da = where(mask, dc, 0.)
db = where(~mask, dc, 0.)
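
These two formulas can be written out as a small custom torch.autograd.Function and compared against torch.where itself. In this sketch WhereSketch is an illustrative name, not a PyTorch API. On random inputs it yields the same gradients for a and b as the built-in torch.where, which suggests that torch.where's backward already follows exactly these formulas.

```python
import torch


class WhereSketch(torch.autograd.Function):
    """Illustrative where(mask, a, b) whose backward is
    da = where(mask, dc, 0), db = where(~mask, dc, 0)."""

    @staticmethod
    def forward(ctx, mask, a, b):
        ctx.save_for_backward(mask)
        return torch.where(mask, a, b)

    @staticmethod
    def backward(ctx, dc):
        (mask,) = ctx.saved_tensors
        zero = torch.zeros_like(dc)
        da = torch.where(mask, dc, zero)
        db = torch.where(~mask, dc, zero)
        return None, da, db  # no gradient for the boolean mask


torch.manual_seed(0)
mask = torch.rand(5) > 0.5
a = torch.randn(5, requires_grad=True)
b = torch.randn(5, requires_grad=True)
dc = torch.randn(5)

ref = torch.autograd.grad(torch.where(mask, a, b), (a, b), dc)
mine = torch.autograd.grad(WhereSketch.apply(mask, a, b), (a, b), dc)
print(all(torch.equal(r, m) for r, m in zip(ref, mine)))  # True
```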

For the earlier example, with mask = (x == 0) (which is True), a = r2 and b = r1, this means:

da = dc = 1
db = 0

I think you are saying that:

dx = f2.backward(da) + f1.backward(db)

Ah, I think I see what you mean now. This will give:

dx = 1. * da + (-0. / x**2) * db
   = 1. + NaN * 0.
   = NaN

But speaking mathematically, this computation is wrong. The grad of c w.r.t. x is clearly 1. With db = 0, it does not matter that the grad of r1 w.r.t. x is undefined (NaN).

I wonder whether autograd should fix this. The chain rule is applied incorrectly here: wherever grad_output is 0, the corresponding contribution to grad_input should also be 0. Currently autograd does not do that.
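
To make the proposal concrete, here is a hypothetical sketch of such a masking rule for the single operation 0 / x, written as a custom torch.autograd.Function. The class name MaskedZeroDiv is made up for illustration and this is not an existing PyTorch feature. Its backward returns exactly 0 wherever the incoming gradient is 0, which makes the gradient of the earlier example come out as 1.

```python
import torch


class MaskedZeroDiv(torch.autograd.Function):
    """Hypothetical op computing 0 / x whose backward returns exactly 0
    wherever the incoming gradient is 0 (the proposed masking rule)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return 0.0 / x  # NaN at x == 0, like f1

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        local = -0.0 / (x * x)        # d/dx (0 / x); NaN at x == 0 under IEEE
        grad_in = grad_out * local    # plain chain rule: 0 * NaN = NaN
        # Proposed rule: a zero upstream gradient gives a zero input gradient
        return torch.where(grad_out == 0, torch.zeros_like(grad_in), grad_in)


x = torch.zeros(1, requires_grad=True)
r1 = MaskedZeroDiv.apply(x)
r2 = x.clone()
out = torch.where(x == 0, r2, r1)
out.backward()
print(x.grad)  # 1 instead of NaN
```

Note the extra torch.where in the backward: a general version of this rule would add such a check to every operator.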

As a mathematician, I would say that the computation is mathematically defined as being IEEE 754 operations, including NaN trumping everything else.
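
The IEEE 754 rules involved can be replayed by hand. The second half of this sketch assumes, as an unverified detail of PyTorch internals, that 0 / x is evaluated as x.reciprocal() * 0 and that the backward of reciprocal is -grad * result**2. Under those assumptions, the zero gradient coming from torch.where meets an infinite local derivative, and -0 * inf is NaN.

```python
import math
import torch

nan, inf = float("nan"), float("inf")

# IEEE 754: NaN survives multiplication by 0 and addition, and 0 * inf is NaN
products = [0.0 * nan, 1.0 + 0.0 * nan, 0.0 * inf]
print(products)                              # [nan, nan, nan]
print(all(math.isnan(p) for p in products))  # True

# Replaying the backward of 0 / x at x = 0 by hand (assumed decomposition:
# 0 / x == x.reciprocal() * 0, reciprocal backward == -grad * result**2)
x = torch.zeros(1)
result = x.reciprocal()                  # inf
grad_out = torch.zeros(1)                # gradient that torch.where sends to r1
grad_recip = grad_out * 0.0              # backward of "* 0": still 0
grad_x = -grad_recip * result * result   # -0 * inf * inf = NaN
print(grad_x)                            # NaN
```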

But, leaving aside my reservations about going around saying “you’re wrong” to people who are trying to help you understand why things are the way they are, it seems that you are trying to start a conversation about the autograd behaviour involved. People have suggested using “masking zeros” for gradients before; it has not gone anywhere because departing from IEEE semantics is quite expensive. It is also much less universally useful than it looks, because not all operations have a pointwise backward, and the conventional wisdom is to avoid generating NaNs rather than trying to deal with them afterwards.

Masked tensors don’t quite do what you want (they don’t add), but maybe you could define your own tensor subclass that takes inspiration from them to achieve this.

Thanks for the clarification. I think I understand the reasoning, and I came to the same conclusion, that currently it is best to avoid generating NaNs.
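
A common way to avoid generating the NaN in the first place (sometimes called the “double where” trick) is to replace the problematic inputs before they reach the unsafe branch, so that neither its forward nor its backward ever sees x == 0. Here is a minimal sketch for the original example; the name g_safe and the placeholder value 1.0 are illustrative choices, and any value at which f1 is well behaved would do.

```python
import torch


def f1(x):
    return 0 / x


def f2(x):
    return x


def g_safe(x):
    mask = x == 0.0
    # Swap x == 0 for a harmless placeholder (1.0) so that
    # f1's forward and backward never see the problematic input
    x_safe = torch.where(mask, torch.ones_like(x), x)
    # The outer where still picks f2(x) at x == 0 and f1 elsewhere
    return torch.where(mask, f2(x), f1(x_safe))


x = torch.zeros(1, requires_grad=True)
(gx,) = torch.autograd.grad(g_safe(x), x)
print(gx)  # 1, matching the gradient of f2
```

Because f1 only ever sees x_safe, its backward produces a finite value, which the zero gradient from the outer torch.where then cancels.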

Still, I wonder whether there is a better solution that is also efficient.

I’m not sure that I agree with the terminology. When I say “wrong”, I mean that the algorithm (the autograd implementation) does not calculate the real gradient correctly.

Or perhaps you are referring to “0 * NaN = NaN”? That’s not what I mean when I say it’s wrong. That part is totally fine: as you say, it is just what IEEE defines.

I am referring to what autograd calculates. It does not calculate the correct gradient: sometimes (as in this example) it returns NaN even though the gradient is well defined. I had always assumed that autograd calculates the correct gradient whenever it is well defined, but the behavior shown here means that assumption does not hold.