NaNs in input data breaking gradients

I have a use case in which I want some components of my input data to be NaN (slate data with a variable number of items per slate; I want to insert NaNs instead of item features for empty item slots). The NaNs propagate through the forward pass properly, and I ignore them when calculating the loss (by using the nanmean function from

The problem is that I am now getting NaN gradients even when NaN outputs are not used as inputs to the loss function. Minimal example:

import torch
import torch.nn as nn

lin_layer = nn.Linear(1, 2)
X_ones = torch.ones(2, 1)
X_ones[0, :] = float('NaN') # [[nan], [1]]

# Approach 1: remove NaN data from the input to the linear layer.
# Gradients are calculated successfully.

X_input = X_ones[1,:] # select non-NaN inputs
output = lin_layer(X_input)
loss = output.mean()
print(f'loss={loss} (mean of {output})')
loss.backward() # lin_layer.weight.grad is finite

# Approach 2: remove NaN data from the input to the loss function.
# Gradient calculation fails.

lin_layer.zero_grad() # discard the gradients from Approach 1
output = lin_layer(X_ones)
loss = output[1,:].mean() # select non-NaN outputs
print(f'loss={loss} (mean of {output[1,:]})')
loss.backward() # lin_layer.weight.grad is NaN


Output:
loss=-0.25358104705810547 (mean of tensor([-0.2736, -0.2336], grad_fn=<AddBackward0>))
loss=-0.25358104705810547 (mean of tensor([-0.2736, -0.2336], grad_fn=<SliceBackward>))
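A minimal sketch that runs both approaches and compares the gradients directly. The fixed seed and the names loss1, w_grad_1, etc. are illustrative additions, not from the original post. It should show that the two losses are identical, the bias gradient is finite in both cases, and only the weight gradient becomes NaN in Approach 2.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed: changes the loss values, not the NaN pattern
lin_layer = nn.Linear(1, 2)
X_ones = torch.ones(2, 1)
X_ones[0, :] = float('nan')  # [[nan], [1]]

# Approach 1: slice the NaN row away *before* the layer.
lin_layer.zero_grad()
loss1 = lin_layer(X_ones[1, :]).mean()
loss1.backward()
w_grad_1 = lin_layer.weight.grad.clone()
b_grad_1 = lin_layer.bias.grad.clone()

# Approach 2: slice the NaN row away *after* the layer.
lin_layer.zero_grad()
loss2 = lin_layer(X_ones)[1, :].mean()
loss2.backward()
w_grad_2 = lin_layer.weight.grad.clone()
b_grad_2 = lin_layer.bias.grad.clone()

print(loss1.item(), loss2.item())   # same value
print(w_grad_1.flatten(), b_grad_1) # all finite
print(w_grad_2.flatten(), b_grad_2) # weight grad is NaN, bias grad is finite
```

The bias gradient survives because its backward pass only sums the upstream gradient; the weight gradient also multiplies by the input, which contains the NaN.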

Unfortunately, any NaN turns every number it touches into NaN, so NaNs tend to propagate. This is the expected behavior here.
You definitely want to mask them out before they enter any computation, as far as possible.
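A minimal sketch of masking before any computation, applied to the example from the question. The names valid and X_clean and the fill value 0 are illustrative assumptions; any finite fill value should work, because the padded rows are excluded from the loss and their (finite) contribution to the weight gradient is then multiplied by zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin_layer = nn.Linear(1, 2)

X = torch.ones(2, 1)
X[0, :] = float('nan')  # padded (empty) item slot

# Derive the mask from the NaNs once, then remove the NaNs before the layer.
valid = ~torch.isnan(X).any(dim=1)                              # True for real items
X_clean = torch.where(torch.isnan(X), torch.zeros_like(X), X)   # hypothetical fill value 0

output = lin_layer(X_clean)   # no NaN anywhere in the graph
loss = output[valid].mean()   # padded rows excluded from the loss
loss.backward()

print(lin_layer.weight.grad.flatten(), lin_layer.bias.grad)  # finite gradients
```

For slate data, the same idea means keeping a boolean "item present" mask alongside the features instead of encoding absence as NaN.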


Thanks for the reply! In Approach 2, the loss depends only on output[1,:], whose elements are non-NaN, so the loss is non-NaN as well. Wouldn't we expect the gradients to propagate properly in this case, since the NaN tensor elements are "irrelevant" (they don't feed into the loss function)?

Or is it the case that if at least one component of a tensor is NaN, then the gradients can’t flow through this tensor?

Unfortunately, PyTorch cannot know about the "irrelevant" part and only applies the chain rule. The gradient flowing back to the unused NaN rows is 0, but computing the weight gradient multiplies it by the NaN input, and in floating point 0 * nan = nan.
There are a few issues on GitHub that discuss this problem.
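The chain-rule step can be reproduced by hand. This is a sketch assuming the standard backward formulas for y = x @ W.T + b (grad_W = grad_out.T @ x, grad_b = grad_out summed over the batch); it does not call PyTorch's internal Linear backward. The weight gradient is written as an explicit sum of products so that every 0 * nan term is visible.

```python
import torch

# IEEE 754 floating point: zero times NaN is NaN, so a zero gradient cannot cancel a NaN.
zero_times_nan = torch.tensor(0.0) * float('nan')

x = torch.tensor([[float('nan')], [1.0]])   # input rows: [nan], [1]
grad_out = torch.tensor([[0.0, 0.0],        # row 0 is not used by the loss -> zero gradient
                         [0.5, 0.5]])       # row 1: d(mean of 2 values)/d(output) = 1/2

# grad_W = grad_out.T @ x, as an explicit sum over the batch dimension:
# each entry is 0 * nan + 0.5 * 1 = nan
grad_W = (grad_out.unsqueeze(2) * x.unsqueeze(1)).sum(dim=0)

# grad_b does not involve the input, so it stays finite.
grad_b = grad_out.sum(dim=0)

print(zero_times_nan, grad_W.flatten(), grad_b)
```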