torch.nansum yields NaN gradients unexpectedly


Let L be the result of a loss function before reduction, possibly containing NaN values.
I expected torch.nansum(L) to accumulate only the non-NaN values of L, so that calling .backward() on the result would produce valid gradients.

However, the result does not match my expectation.

Here is a snippet to reproduce the problem.

import torch


def test():
    # net, opt
    net = torch.nn.Linear(3, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)

    # data (x), target (y) with NaN wherever y < 0
    x = torch.randn(2, 3)
    y = torch.randn(2, 2)
    y = torch.where(y < 0, torch.tensor(float('nan')), y)  # in the run described below: NaN at (1,0)

    # calculate per-element loss (no reduction)
    out = net(x)
    loss = torch.nn.MSELoss(reduction='none')(out, y)
    # accumulate losses except NaN
    loss = loss.nansum()

    # backward, train
    opt.zero_grad()
    loss.backward()
    opt.step()

    # see if any parameter became NaN
    assert not torch.any(net.weight.isnan())

    print('Test passed')


test()

As loss contains three valid numbers and one NaN,
I expected loss.nansum() to replace the single NaN with 0 so that the following loss.backward() would calculate gradients without NaN.
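In isolation, nansum does behave this way in backward: entries that hold NaN receive a zero gradient. A minimal check, using an illustrative leaf tensor v that is not part of the original example:

```python
import torch

# Illustrative leaf tensor with one NaN entry.
v = torch.tensor([1.0, float('nan'), 2.0], requires_grad=True)

total = v.nansum()  # NaN is skipped in the forward pass
total.backward()    # and gets a zero gradient in the backward pass

print(total.item())  # 3.0
print(v.grad)        # tensor([1., 0., 1.])
```

So the NaN gradients in the full example must come from an operation earlier in the graph, not from nansum itself.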

However, all the values in net.weight.grad become NaN.

I wonder:
(1) if this is a bug, and
(2) what would be a correct / better way to calculate the loss in my problem.

Edit: added the import statement so the example is stand-alone.
I tested this under Python 3.7.10 and PyTorch 1.8.0.


You can use tensor.register_hook(your_fn) to print the gradient of any given Tensor.
In particular, here, I think that the gradient of the loss before the nansum won't have any NaN, as you expect.
The problem is that the backward of the MSE creates NaN values because some entries of y are NaN: the MSE gradient is proportional to (out - y), which is NaN at those entries, and multiplying it by the zero gradient coming back from nansum still gives NaN (0 * NaN = NaN). You will need to filter these entries out before computing the MSE to avoid NaN in the gradients flowing back.
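A minimal sketch of the register_hook approach, assuming a fixed target with a single NaN at (1, 0) in place of the random one from the question. The names grads and save_grad are illustrative helpers, not part of any API:

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(3, 2)
x = torch.randn(2, 3)
# Illustrative fixed target with a single NaN at (1, 0).
y = torch.tensor([[0.5, 1.0], [float('nan'), 2.0]])

grads = {}

def save_grad(name):
    # Hook called during backward with the gradient w.r.t. the tensor.
    def hook(grad):
        grads[name] = grad.clone()
    return hook

out = net(x)
per_elem = torch.nn.MSELoss(reduction='none')(out, y)

per_elem.register_hook(save_grad('per_elem'))
out.register_hook(save_grad('out'))

per_elem.nansum().backward()

print(grads['per_elem'])  # 1 at valid entries, 0 at the NaN entry: no NaN yet
print(grads['out'])       # NaN at (1, 0): 2 * (out - y) * 0 with y = NaN
print(net.weight.grad)    # row 0 is all NaN, because out[:, 0] depends on it
```

The gradient entering the per-element loss is clean; the NaN first appears in the gradient of out, i.e. inside the MSE backward.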

Hi Alban,

Thank you so much for your help.
My understanding now is that because MSELoss uses its inputs in its backward pass, once a NaN appears there it propagates into the gradients regardless of whatever operation I apply afterwards.

I used the following code instead to resolve the issue.

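    # loss = torch.nn.MSELoss(reduction='none')(out, y)
    nan = torch.isnan(y)                            # mask of missing targets
    y = torch.where(nan, torch.tensor(0.0), y)      # replace NaN targets with 0
    out = torch.where(nan, torch.tensor(0.0), out)  # and the matching outputs too,
    loss = (out - y) ** 2                           # so masked entries give 0 loss and 0 gradient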
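An alternative sketch (a different technique from the torch.where masking above) selects only the valid entries with boolean indexing, so NaN targets never enter the graph at all. The fixed target y and the name valid are illustrative assumptions:

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(3, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(2, 3)
y = torch.tensor([[0.5, 1.0], [float('nan'), 2.0]])  # illustrative target, NaN at (1, 0)

out = net(x)
valid = ~torch.isnan(y)  # True where the target is usable
# Indexing keeps only valid entries (a 1-D tensor); backward scatters
# gradients to those positions and zeros elsewhere.
loss = torch.nn.functional.mse_loss(out[valid], y[valid])  # mean over valid entries

opt.zero_grad()
loss.backward()
opt.step()

print(net.weight.grad)  # finite everywhere
```

One design difference: this averages over the valid entries only, whereas calling .mean() on the torch.where version would also count the zeroed-out entries in the denominator.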