I have a PyTorch tensor that contains NaN values. When I compute a simple MSE loss on it, the gradient becomes NaN even though I mask out the NaN entries.
Oddly, this happens only when the mask is applied after computing the loss, and only when the loss contains a pow operation. The different cases follow:
import torch
torch.autograd.set_detect_anomaly(True)
x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[y > 0.5] = torch.nan
o = w @ x

# Case 1: mask applied after the pow -> backward raises (nan gradient)
l = (y - o) ** 2
l = l[~y.isnan()]
try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    print('(y-o)**2 caused nan gradient')

# Case 2: no pow, mask applied after the subtraction -> no nan gradient
l = y - o
l = l[~y.isnan()]
try:
    l.mean().backward(retain_graph=True)
except RuntimeError:
    pass
else:
    print('y-o does not cause nan gradient')

# Case 3: mask applied before the pow -> no nan gradient
l = (y[~y.isnan()] - o[~y.isnan()]) ** 2
l.mean().backward()
print('masking before pow does not propagate nan gradient')
What makes NaN gradients propagate when passing through the backward pass of the pow function?
In the first case, l = l[~y.isnan()] creates a new tensor that you also call l. This new
tensor does not contain nans, but it depends on the old l = (y - o)**2, which is saved
in the computation graph and still contains nans.
Because both the new l and the old l depend on the leaf
variable w that carries requires_grad=True, the backward pass has to
go through the old l. The backward of the indexing op scatters the incoming
gradient into a full-size tensor with zeros at the masked positions, and the
backward of the pow then multiplies it elementwise by the local gradient
2 * (y - o), which is nan wherever y is nan. Since 0 * nan = nan, those nans
survive the mask and you hit the error.
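A minimal sketch of that mechanism, with two-element tensors as assumed stand-ins for the ones in the question (anomaly detection left off so the nan gradient can be printed instead of raising):

import torch

y = torch.tensor([1.0, torch.nan])
o = torch.tensor([0.5, 0.5], requires_grad=True)
l = (y - o) ** 2          # full tensor: [0.25, nan]
l = l[~y.isnan()]         # masked tensor: [0.25], no nan
l.mean().backward()
print(o.grad)             # tensor([-1., nan]) -- 0 * nan = nan at the masked slot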
In the second case you again backpropagate through the new l and the old l = y - o. But
now the local gradient of the subtraction is a constant (-1 with respect to o), so the nans in y never enter the backward pass: the zeros scattered by the indexing backward stay zero.
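The same assumed two-element sketch, but without the pow, showing that the scattered zeros stay zero:

import torch

y = torch.tensor([1.0, torch.nan])
o = torch.tensor([0.5, 0.5], requires_grad=True)
l = (y - o)[~y.isnan()]   # masked tensor: [0.5]
l.mean().backward()
print(o.grad)             # tensor([-1., 0.]) -- no nan reaches the gradient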
In the third case, y[~y.isnan()] is a new tensor (that you haven't given
a name to) that does not contain nans, and likewise o[~y.isnan()]
contains no nans. The local gradient of the pow is
2 * (y[~y.isnan()] - o[~y.isnan()]), which also contains no nans, so nothing
in the backward pass can produce a nan and there is no error.
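As a usage sketch of that safe pattern, reusing the question's setup (with an assumed manual seed added only for reproducibility), masking both y and o before the pow keeps w.grad finite:

import torch

torch.manual_seed(0)       # assumed seed, only for reproducibility
x = torch.rand(10, 10)
y = torch.rand(10, 10)
w = torch.rand(10, 10, requires_grad=True)
y[y > 0.5] = torch.nan

o = w @ x
mask = ~y.isnan()
loss = ((y[mask] - o[mask]) ** 2).mean()
loss.backward()
print(torch.isnan(w.grad).any())   # tensor(False)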