# `torch.nansum` yields NaN gradients unexpectedly

Hi,

Let L be the result of a loss function before reduction, possibly with NaN values.
I expected `torch.nansum(L)` to accumulate only the non-NaN values of L, so that `.backward()` would produce valid gradients.

However, my expectation does not match the result.

Here is a snippet to reproduce the problem.

```python
def test():
    import torch

    torch.manual_seed(3)

    # net, opt
    net = torch.nn.Linear(3, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    # data (x), target (y) with NaN
    x = torch.randn(2, 3)
    y = torch.randn(2, 2)
    y = torch.where(y < 0, torch.tensor(float('nan')), y)  # NaN at (1,0)

    # calculate loss
    out = net(x)
    loss = torch.nn.MSELoss(reduction='none')(out, y)

    # accumulate losses except NaN
    loss = loss.nansum()

    # backward, train
    loss.backward()
    opt.step()

    # see if any parameter became NaN
    assert not torch.any(net.weight.isnan())

    print('Test passed')
```

As `loss` contains three valid numbers and one NaN,
I expected `loss.nansum()` to replace the single NaN with 0, so that the following `loss.backward()` would calculate gradients without NaN.

However, all the values in `net.weight.grad` become NaN.

I wonder:
(1) if this is a bug, and
(2) what would be a correct / better way to calculate the loss in my problem.

Edit: added the import statement to make this a stand-alone example.
I tested under Python 3.7.10, PyTorch 1.8.0.

Hi,

You can use `tensor.register_hook(your_fn)` to print the gradient of any given Tensor.
In particular, here, I think the gradient of the loss before the nansum won't contain any NaN, as you expect.
The problem is that the backward of the MSE creates NaN values because some entries of `y` are NaN: the zero gradient coming out of nansum is multiplied by `2 * (out - y)`, which is NaN at those entries, and `0 * NaN` is still NaN. You will need to filter these entries out before computing the MSE to keep NaN out of the gradients flowing back.
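To make the hook suggestion concrete, here is a minimal sketch of how the gradients could be inspected. It is not the original poster's exact setup: the seed, the `grads` dictionary, and the NaN forced at `y[1, 0]` are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, chosen for this sketch only

net = torch.nn.Linear(3, 2)
x = torch.randn(2, 3)
y = torch.randn(2, 2)
y[1, 0] = float('nan')  # assumption: force exactly one NaN target

grads = {}  # hypothetical container for the gradients seen by the hooks

out = net(x)
# The hooks return None, so they only record the gradient and do not modify it.
out.register_hook(lambda g: grads.__setitem__('out', g.clone()))

loss = torch.nn.MSELoss(reduction='none')(out, y)
loss.register_hook(lambda g: grads.__setitem__('loss', g.clone()))

loss.nansum().backward()

print(grads['loss'])    # gradient leaving nansum: 0 at the NaN position, 1 elsewhere
print(grads['out'])     # gradient after the MSE backward: NaN at position (1, 0)
print(net.weight.grad)  # the NaN has reached the parameters
```

The gradient recorded on `loss` is clean, while the one recorded on `out` already contains NaN, which points at the MSE backward rather than at `nansum`.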

Hi Alban,

Thank you so much for your help.
My understanding now is that because MSELoss uses its inputs during backpropagation, once a NaN has appeared there it ends up in the gradient regardless of which operation I apply afterwards.

I used the following code instead to resolve the issue.

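```python
# loss = torch.nn.MSELoss(reduction='none')(out, y)
nan = torch.isnan(y)                            # mask of NaN targets
y = torch.where(nan, torch.tensor(0.0), y)      # replace NaN targets with 0
out = torch.where(nan, torch.tensor(0.0), out)  # zero the matching outputs, so no gradient flows there
loss = (out - y) ** 2                           # element-wise squared error, 0 at masked positions
```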
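Another way to do the same filtering is to select only the valid entries with a boolean mask before computing the loss. The following is a sketch under the same illustrative assumptions (arbitrary seed, one NaN forced at `y[1, 0]`, and a hypothetical mask name `valid`), not a statement of the recommended approach.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, chosen for this sketch only

net = torch.nn.Linear(3, 2)
x = torch.randn(2, 3)
y = torch.randn(2, 2)
y[1, 0] = float('nan')  # assumption: force exactly one NaN target

out = net(x)

# Keep only the positions whose target is a real number.
valid = ~torch.isnan(y)

# The NaN target never enters the MSE, so its backward cannot produce NaN.
loss = torch.nn.functional.mse_loss(out[valid], y[valid], reduction='sum')

loss.backward()

print(torch.isnan(net.weight.grad).any())  # expected to print tensor(False)
```

Here `reduction='sum'` stands in for the `nansum()` of the original snippet; the masked-out position simply receives a zero gradient.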

Thanks,
Seongmin
