Let L be the output of a loss function before reduction, which may contain NaN values.
I use torch.nansum(L) to accumulate only the non-NaN values of L, expecting that
calling .backward() on the result would produce valid gradients.
However, the result does not match my expectation.
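As an isolated check of what torch.nansum does on its own (a small sketch; the tensor values are arbitrary and chosen only for illustration), the forward pass skips the NaN entry and the gradient sent back to that entry is 0:

```python
import torch

# Arbitrary example values; the middle entry plays the role of a NaN loss term.
v = torch.tensor([1.0, float('nan'), 2.0], requires_grad=True)

s = torch.nansum(v)  # forward: NaN entries are treated as 0
s.backward()

print(s.item())  # 3.0
print(v.grad)    # tensor([1., 0., 1.])
```

So nansum by itself does not produce a NaN gradient; the gradient it passes back to the NaN position is exactly 0.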
Here is a snippet to reproduce the problem.
    def test():
        import torch
        torch.manual_seed(3)

        # net, opt
        net = torch.nn.Linear(3, 2)
        opt = torch.optim.SGD(net.parameters(), lr=0.1)

        # data (x), target (y) with NaN
        x = torch.randn(2, 3)
        y = torch.randn(2, 2)
        y = torch.where(y < 0, torch.tensor(float('nan')), y)  # NaN at (1,0)

        # calculate loss
        out = net(x)
        loss = torch.nn.MSELoss(reduction='none')(out, y)

        # accumulate losses except NaN
        loss = loss.nansum()

        # backward, train
        opt.zero_grad()
        loss.backward()
        opt.step()

        # see if any parameter became NaN
        assert not torch.any(net.weight.isnan())
        print('Test passed')

    test()
Here, loss contains three valid numbers and one NaN.
I call loss.nansum() so that the single NaN is treated as 0, expecting the subsequent
loss.backward() to compute gradients without NaN.
However, all the values in
net.weight.grad become NaN.
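The NaN can be reproduced without the network. In this minimal sketch (the two output/target values are hypothetical, not taken from the snippet), nansum sends a gradient of 0 to the NaN entry, but the chain rule then multiplies that 0 by the derivative of the squared error, 2 * (out - y), which is NaN when y is NaN. Under IEEE 754 arithmetic, NaN * 0 is NaN, so the NaN reaches out.grad anyway:

```python
import torch

# Hypothetical values: the second target is NaN, like the missing entry in y.
out = torch.tensor([0.5, 1.0], requires_grad=True)
y = torch.tensor([1.0, float('nan')])

per_elem = torch.nn.MSELoss(reduction='none')(out, y)  # [0.25, nan]
loss = per_elem.nansum()                               # 0.25, NaN skipped
loss.backward()

# Chain rule: d(loss)/d(out) = 2 * (out - y) * upstream_grad.
# For the NaN entry the upstream grad from nansum is 0, but
# 2 * (1.0 - nan) * 0 = nan, so the NaN leaks into out.grad.
print(loss.item())  # 0.25
print(out.grad)     # tensor([-1., nan])
```

In the network from the snippet, the linear layer's backward pass then combines this NaN entry of out.grad with x, which spreads the NaN into net.weight.grad.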
I would like to know:
(1) whether this is a bug, and
(2) what a correct (or better) way to compute the loss would be in my case.
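One common workaround, offered here as a sketch rather than a confirmed answer, is to drop the NaN targets before the loss is computed, so that no NaN ever enters the autograd graph. It reuses the setup from the snippet; the name mask and the choice of reduction='sum' (to mirror nansum) are assumptions:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(3)
net = torch.nn.Linear(3, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Same data setup as in the snippet: negative targets become NaN.
x = torch.randn(2, 3)
y = torch.randn(2, 2)
y = torch.where(y < 0, torch.tensor(float('nan')), y)

out = net(x)

# Keep only the positions whose target is a real number.
mask = ~torch.isnan(y)
# Boolean indexing flattens to the valid elements; 'sum' mirrors nansum.
loss = F.mse_loss(out[mask], y[mask], reduction='sum')

opt.zero_grad()
loss.backward()
opt.step()

# Neither the gradient nor the updated weights contain NaN.
assert not torch.isnan(net.weight.grad).any()
assert not torch.isnan(net.weight).any()
print('Test passed')
```

An equivalent variant fills the NaN targets with a finite placeholder, for example torch.where(mask, y, torch.zeros_like(y)), computes the per-element loss with reduction='none', and multiplies it by mask before summing. In both cases the NaN is removed before the loss is computed rather than after.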
Edit: added the import statement so that the example is stand-alone.
I tested this with Python 3.7.10 and PyTorch 1.8.0.