Gradients contaminated by unused inputs

I have the following minimal example:

import torch

torch.manual_seed(137)

L = torch.nn.Linear(8, 1)
x = torch.cat((torch.rand((4, 8)), torch.full((4, 8), float('nan'))), dim=0)  # last 4 rows are NaN
loss = L(x)[:4].sum()  # only the first 4 rows contribute to the loss
loss.backward()

print(L.weight.grad)

The gradients are NaN, although the output does not depend on the NaN entries.
Is this behaviour expected?
(In more complex architectures the same behaviour can be triggered by setting the unused entries to values that are close to overflowing 32-bit floats.)
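For illustration, here is a rough sketch of that variant; the ** 3 is just a stand-in for later layers whose backward pass touches the large intermediate values:

import torch

torch.manual_seed(137)

L = torch.nn.Linear(8, 1)
# the last 4 rows hold large (but still finite) float32 values instead of NaN
x = torch.cat((torch.rand((4, 8)), torch.full((4, 8), 1e30)), dim=0)
h = L(x)                    # finite everywhere
loss = (h ** 3)[:4].sum()   # only the first 4 rows enter the loss
loss.backward()

print(L.weight.grad)        # typically all NaN, even though no input is NaN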

Hi Leander!

I would say that it’s expected, although in some cases, such as yours,
maybe not desirable.

The cause is that PyTorch computes gradients during the backward
pass by applying the chain rule numerically.

Let’s look at the gradient of the slicing operation:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> a = torch.ones(8, requires_grad=True)
>>> a
tensor([1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)
>>> a[:4].sum().backward()
>>> a.grad
tensor([1., 1., 1., 1., 0., 0., 0., 0.])

That’s reasonable – the elements that were not included in the slice
have zero gradient.

We now backpropagate the gradient of the slicing operation through
the application of the Linear layer, L. Simplifying things a bit, you
are calculating x @ L.weight.t() (plus a bias). The gradient (more
precisely, the Jacobian) of this expression with respect to L.weight
is, in essence, x. Letting g be the gradient coming back from the
slicing operation, the chain rule gives us, more or less, g.t() @ x
(each row of x weighted by the corresponding entry of g).
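To make this concrete, here is a small sketch with finite inputs. It is not literally what autograd does internally, but for this case it comes out to the same numbers:

import torch

torch.manual_seed(137)

L = torch.nn.Linear(8, 1)
x = torch.rand((8, 8))

L(x)[:4].sum().backward()

# g is the gradient produced by the slicing operation (ones for the kept
# rows, zeros for the dropped rows)
g = torch.cat((torch.ones(4, 1), torch.zeros(4, 1)), dim=0)
print(torch.allclose(L.weight.grad, g.t() @ x))   # True

If any entry in the last four rows of x were nan, that same g.t() @ x product would come out nan, even though the corresponding entries of g are 0.0.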

The problem is that by the time g gets to this step in the backpropagation,
autograd no longer knows that some elements were ignored by the
slicing operation – all it knows is that some elements of g happen
to be 0.0. (Perhaps those zeros were computed numerically, e.g., as
2.0 - 6.0 / 3.0, rather than coming from an “ignore” operation.)

According to the highly desirable rules of floating-point arithmetic,
nan * 0.0 = nan. Autograd doesn’t know that it’s supposed to
ignore those nans in x; instead, it multiplies them by 0.0
and thus gets nans as the result.
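You can check this directly:

>>> torch.tensor(float('nan')) * 0.0
tensor(nan)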

In a hypothetical world, floating-point numbers could have a special
ignore value, in addition to things like nan and inf. In such a
world, the gradient of the slicing operation could have ignore for
the gradient of the elements not included in the slice. Then, perhaps,
autograd could say that nan * ignore = 0.0 (or something), and
you would get the result you were hoping for. But we don’t live in
that hypothetical world.
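In our actual world, the practical fix is to keep the nans (or the near-overflow values) from ever reaching that multiplication, for example by slicing or masking x before applying L rather than after. A rough sketch:

import torch

torch.manual_seed(137)

L = torch.nn.Linear(8, 1)
x = torch.cat((torch.rand((4, 8)), torch.full((4, 8), float('nan'))), dim=0)

# option 1: slice first, so the nan rows never reach the Linear layer
L(x[:4]).sum().backward()
print(L.weight.grad)   # finite

L.weight.grad = None   # reset the accumulated gradients before option 2
L.bias.grad = None

# option 2: overwrite the unused entries with something harmless before the op
mask = torch.cat((torch.ones(4, 8), torch.zeros(4, 8)), dim=0).bool()
x_clean = torch.where(mask, x, torch.zeros_like(x))
L(x_clean)[:4].sum().backward()
print(L.weight.grad)   # finite, and the same value as option 1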

Best.

K. Frank

Thank you for the detailed explanation! It makes perfect sense; the result is just a bit counter-intuitive.