Hi,
Let `L` be the result of a loss function before reduction, possibly containing NaN values. I expected `torch.nansum(L)` to accumulate the non-NaN values of `L` so that `.backward()` would produce valid gradients. However, the result does not match my expectation.
Here is a snippet to reproduce the problem.
```python
def test():
    import torch
    torch.manual_seed(3)
    # net, opt
    net = torch.nn.Linear(3, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    # data (x), target (y) with NaN
    x = torch.randn(2, 3)
    y = torch.randn(2, 2)
    y = torch.where(y < 0, torch.tensor(float('nan')), y)  # NaN at (1,0)
    # calculate loss
    out = net(x)
    loss = torch.nn.MSELoss(reduction='none')(out, y)
    # accumulate losses except NaN
    loss = loss.nansum()
    # backward, train
    opt.zero_grad()
    loss.backward()
    opt.step()
    # see if any parameter became NaN
    assert not torch.any(net.weight.isnan())
    print('Test passed')

test()
```
As `loss` contains three valid numbers and one NaN, I expected `loss.nansum()` to replace the single NaN with 0 so that the subsequent `loss.backward()` would calculate gradients without NaN. However, all the values in `net.weight.grad` become NaN.
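For what it's worth, the same behavior shows up without a network at all. The snippet below is a minimal sketch (tensor values chosen only for illustration) where the NaN enters through the subtraction inside the element-wise loss:

```python
import torch

# One finite target and one NaN target.
out = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.tensor([1.5, float('nan')])

# Element-wise squared error, then nansum to skip the NaN entry.
loss = ((out - y) ** 2).nansum()
loss.backward()

# Even though nansum ignored the NaN entry in the forward pass,
# the gradient at that position is NaN.
print(out.grad)
```

So `nansum` only masks NaN in the forward direction; the NaN produced by `out - y` is still part of the graph and poisons the backward pass.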
I wonder:
(1) if this is a bug, and
(2) what would be a correct / better way to calculate the loss in my problem.
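For context on (2), one workaround I have considered is to drop the NaN targets with a boolean mask before the loss is computed, so that NaN never enters the graph at all (a sketch using the same setup as the snippet above):

```python
import torch

torch.manual_seed(3)
net = torch.nn.Linear(3, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

x = torch.randn(2, 3)
y = torch.randn(2, 2)
y = torch.where(y < 0, torch.tensor(float('nan')), y)

out = net(x)
# Keep only positions with a finite target; indexing with the mask
# means the NaN entries never participate in the forward or backward pass.
mask = ~y.isnan()
loss = torch.nn.functional.mse_loss(out[mask], y[mask])

opt.zero_grad()
loss.backward()
opt.step()
print(net.weight.grad)
```

I am not sure whether this is the idiomatic solution, though.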
Edit: added the import statement to make this a stand-alone example.
Tested under Python 3.7.10 and PyTorch 1.8.0.