Handling None Labels in Dataset Leads to NaN Gradients and Predictions

Hi everyone,

I’ve encountered an issue while training a model on a dataset where some samples have None labels. To handle these cases, I use reduction="none" on the loss function and set the per-sample loss to 0 wherever the label is missing. Here’s a simplified version of my approach:

import torch
from torch import optim, nn
from torch.utils.data import DataLoader

# Dummy data
x = torch.randn(100, 10)
y = torch.randn(100, 1)

# Simulate missing (None) labels by setting some random labels to NaN
nan_indexes = torch.randperm(100)[:10]
y[nan_indexes] = float("nan")

dl = DataLoader(list(zip(x, y)), batch_size=10)

# Model, loss function, and optimizer
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
criterion = nn.MSELoss(reduction="none")
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 20

for epoch in range(num_epochs):
    for x, y in dl:
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)

        # Handle missing labels: zero out their per-sample loss entries
        nan_mask = torch.isnan(y)
        loss[nan_mask] = 0
        loss = loss.sum() / (~nan_mask).sum()  # average over the valid samples only

        loss.backward()
        optimizer.step()
        print(loss.item())

The issue arises after a few steps: the gradient norms become NaN while the loss drops to 0. Eventually the model’s predictions become NaN as well, and training falls apart.

Here’s what I’ve tried:

  • Verified that only the samples with missing (NaN) labels have their loss zeroed out.
  • Checked that the data inputs do not contain NaN values (roughly as in the snippet after this list).
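
For reference, the checks looked roughly like this (a rough sketch of what I ran, not the exact code):

# data check: make sure no inputs are NaN
assert not torch.isnan(x).any(), "inputs contain NaN"

# gradient check inside the training loop, after loss.backward()
for name, p in model.named_parameters():
    if p.grad is not None and torch.isnan(p.grad).any():
        print(f"NaN gradient in {name}")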

Despite these checks, the problem persists. Has anyone faced a similar issue, or does anyone have suggestions on how to properly handle None labels in the dataset without ending up with NaN gradients?

Thanks in advance for your help!

Edit: It seems that if I replace the NaN labels before calculating the loss, I don’t get NaNs. Is this the correct procedure if I want to ignore samples on the fly?

Hi Lpolisi!

You talk about labels with None values, but, in your example, you have
labels (your y) with nan values. None and nan are different. Following
your example, I will talk about nans, as they are the cause of your issue.

Yes. nans, by design, pollute everything. (For example, 0 * nan = nan.)
So even though you zero out the nans in loss, you still have nans in the
computation graph and they pollute the gradients when you backpropagate.
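
Here is a tiny sketch (not your exact code) showing how zeroing out the loss
entries after the fact still lets the nans leak into the gradients:

import torch

t = torch.arange (5.0, requires_grad = True)
y = torch.ones (5)
y[1] = float ('nan')

loss = (t - y)**2             # the nan enters the computation graph here
loss[torch.isnan (y)] = 0.0   # zeroing out after the fact is too late
loss.sum().backward()
print (t.grad)                # element 1 of the gradient is nan (0 * nan = nan)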

Yes, this is the correct approach. Get rid of the nans before you perform
any computation that mixes them with a requires_grad = True tensor
and you will avoid having any nans lurking in your computation graph that
are just waiting to pollute your gradients when you backpropagate.

Consider:

>>> import torch
>>> print (torch.__version__)
2.3.1
>>>
>>> t = torch.arange (5.0, requires_grad = True)
>>> y = torch.ones (5)
>>> y[1:3] = float ('nan')
>>>
>>> nan_indices = torch.isnan (y).nonzero()
>>> y
tensor([1., nan, nan, 1., 1.])
>>> nan_indices
tensor([[1],
        [2]])
>>>
>>> loss_unr = (t - y)**2
>>> loss_unr                      # damage is done, need to get rid of nans first
tensor([1., nan, nan, 4., 9.], grad_fn=<PowBackward0>)
>>>
>>> # get rid of nans
>>> y[nan_indices] = 99.0
>>> loss_unr = (t - y)**2         # no nans hiding in computation graph
>>> loss_unr                      # no nans
tensor([1.0000e+00, 9.6040e+03, 9.4090e+03, 4.0000e+00, 9.0000e+00],
       grad_fn=<PowBackward0>)
>>>
>>> loss_unr[nan_indices] = 0.0   # zero out bogus values
>>> loss = loss_unr.sum()
>>> loss                          # no nans
tensor(14., grad_fn=<SumBackward0>)
>>> loss.backward()
>>> t.grad                        # grad is not polluted with nans
tensor([-2., -0., -0.,  4.,  6.])
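
Applied to your training loop, a sketch of this (reusing the criterion, model,
and optimizer from your own code) could look something like:

for x, y in dl:
    optimizer.zero_grad()

    nan_mask = torch.isnan (y)
    y_clean = y.clone()
    y_clean[nan_mask] = 0.0                # any dummy value; these entries never count

    outputs = model (x)
    loss = criterion (outputs, y_clean)    # no nans ever enter the computation graph
    loss[nan_mask] = 0.0                   # zero out the bogus entries
    loss = loss.sum() / (~nan_mask).sum()  # average over the valid samples only

    loss.backward()
    optimizer.step()

You could equally select just the valid samples, e.g.
criterion (outputs[~nan_mask], y[~nan_mask]).mean(), as long as the nans never
get mixed into any computation involving a requires_grad = True tensor.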

Best.

K. Frank