I have a use case in which I want some components of my input data to be NaN (slate data with a variable number of items per slate; I insert NaNs in place of item features for empty item slots). The NaNs propagate through the forward pass as expected, and I ignore them when calculating the loss (using the nanmean function from https://github.com/pytorch/pytorch/issues/21987).
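For reference, my nanmean is essentially a masked mean, roughly along these lines (a simplified sketch, not the exact snippet from that issue):

import torch

def nanmean(x, dim=None):
    # Zero out the NaN entries so they contribute nothing to the sum,
    # then divide by the number of non-NaN entries.
    mask = torch.isnan(x)
    x = torch.where(mask, torch.zeros_like(x), x)
    if dim is None:
        return x.sum() / (~mask).sum()
    return x.sum(dim) / (~mask).sum(dim)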
The problem is that I am now getting NaN gradients even when NaN outputs are not used as inputs to the loss function. Minimal example:
import torch
import torch.nn as nn

lin_layer = nn.Linear(1, 2)
X_ones = torch.ones(2, 1)
X_ones[0, :] = float('NaN') # [[nan], [1]]
# Approach 1: remove NaN data from input to the linear layer.
# Gradients calculated successfully
X_input = X_ones[1,:] # select non-NaN inputs
output = lin_layer(X_input)
loss = output.mean()
print(f'loss={loss} (mean of {output})')
lin_layer.zero_grad()
loss.backward()
print(f'grad={lin_layer.weight.grad}')
# Approach 2: remove NaN data from input to the loss function.
# Gradient calculation produces NaNs
output = lin_layer(X_ones)
loss = output[1,:].mean() # select non-NaN outputs
print(f'loss={loss} (mean of {output[1,:]})')
lin_layer.zero_grad()
loss.backward()
print(f'grad={lin_layer.weight.grad}')
Output:
loss=-0.25358104705810547 (mean of tensor([-0.2736, -0.2336], grad_fn=<AddBackward0>))
grad=tensor([[0.5000],
        [0.5000]])
loss=-0.25358104705810547 (mean of tensor([-0.2736, -0.2336], grad_fn=<SliceBackward>))
grad=tensor([[nan],
        [nan]])