Training with missing labels


I’m training a net on a multi-task problem. Let’s say the input is a facial image and there are two losses: cross entropy on gender and cross entropy on age (age has 101 classes), so that total_loss = loss_age + loss_gender.

The problem is that some images have a missing label, e.g. an image where only the age is given, or only the gender.
I know that I should set the loss to 0 (or any constant) for a missing label, but I’m not sure whether I should write a new loss/criterion manually or whether there’s an elegant way to do that in PyTorch. How would you go about it?


First, you should treat age as a regression problem unless there is a good reason you need to do classification.

Second, you can simply index into the output of your model, select the rows for which you have a label, and calculate the loss only on those rows.
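A minimal sketch of the indexing approach, assuming missing labels are encoded as -1 and that the network has two heads producing hypothetical `age_logits` (101 classes) and `gender_logits` (2 classes). The encoding and all names here are illustrative, not from the thread.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch_size = 4
# Hypothetical outputs of a two-headed network: 101 age classes, 2 gender classes.
age_logits = torch.randn(batch_size, 101, requires_grad=True)
gender_logits = torch.randn(batch_size, 2, requires_grad=True)

# Assumption: -1 marks a missing label.
age_target = torch.tensor([25, -1, 40, -1])
gender_target = torch.tensor([-1, 1, 0, -1])


def masked_ce(logits, target):
    """Cross entropy computed only on the rows that have a label."""
    mask = target >= 0
    if not mask.any():
        # No labels for this task in the batch: contribute nothing,
        # but keep the result attached to the graph so backward() still works.
        return logits.sum() * 0.0
    return F.cross_entropy(logits[mask], target[mask])


loss_age = masked_ce(age_logits, age_target)
loss_gender = masked_ce(gender_logits, gender_target)
total_loss = loss_age + loss_gender
total_loss.backward()

# Rows without an age label receive no gradient from the age loss.
print(age_logits.grad[1].abs().sum().item())
```

The `mask.any()` guard matters because a batch may contain no labels at all for one task; without it, `F.cross_entropy` would be averaging over an empty selection.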

Classification works better for age estimation (take a look at any age estimation paper or the recent papers from CVPR).

Also, my original suggestion of setting the loss to 0 or a constant is much cleaner than yours. My question was whether there is a more elegant way.

Regarding the underlying mathematics, there is no difference between setting the loss to zero and not calculating/backpropagating it at all: a loss term fixed at zero (or any constant) has zero gradient, so no update comes from it, which is the same result as not calculating a loss for that part at all.
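This equivalence can be checked numerically. A minimal sketch, assuming per-sample losses from `reduction="none"` and a hypothetical 0/1 `mask` (1 = label present): zeroing the masked entries and summing gives the same gradient as computing the loss on the labeled rows only. With a mean instead of a sum, the two differ by a constant factor per batch, because the zeroed rows still count in the denominator.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

logits = torch.randn(4, 3, requires_grad=True)
target = torch.tensor([2, 0, 1, 1])
# Hypothetical mask: 1 = label present, 0 = label missing.
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Variant A: compute every per-sample loss, zero the missing ones, then sum.
per_sample = F.cross_entropy(logits, target, reduction="none")
loss_zeroed = (per_sample * mask).sum()
(grad_zeroed,) = torch.autograd.grad(loss_zeroed, logits)

# Variant B: compute the loss on the labeled rows only.
keep = mask.bool()
loss_skipped = F.cross_entropy(logits[keep], target[keep], reduction="sum")
(grad_skipped,) = torch.autograd.grad(loss_skipped, logits)

print(torch.allclose(grad_zeroed, grad_skipped))
# Masked rows receive no gradient in either variant.
print(grad_zeroed[1].abs().sum().item(), grad_zeroed[3].abs().sum().item())
```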

On the implementation side, you could use negative values to indicate missing labels and then simply do something like

pred[target<0] = target[target<0]

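and then calculate the loss as before. This results in a loss value of 0 for those samples, since prediction and target are equal. Note that this assumes pred and target have the same shape, as in a regression setup with an MSE loss; it does not apply directly to cross entropy on class indices.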
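For a regression target (age as a single float, as suggested earlier in the thread), the trick might look like this. Assumptions: `pred` and `target` have the same shape, missing labels are stored as -1, and `torch.where` is used instead of the in-place assignment. The out-of-place form avoids modifying a network output in place, which can fail at backward time when the op that produced it saved its output for the gradient (a sigmoid, for example).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a network's age output: one value per image.
raw = torch.randn(5, requires_grad=True)
pred = raw * 10.0 + 30.0

# Assumption: -1 marks a missing age label.
target = torch.tensor([23.0, -1.0, 41.0, 35.0, -1.0])

missing = target < 0
# Copy the target into the prediction wherever the label is missing, so those
# entries have zero error and no gradient path back to `raw`.
pred_filled = torch.where(missing, target, pred)

loss = F.mse_loss(pred_filled, target)
loss.backward()

print(raw.grad[missing])  # zeros for the two missing entries
```

Because `F.mse_loss` averages over all five entries, the missing ones still count in the denominator; summing the loss and dividing by `(~missing).sum()` is one alternative if that matters.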


Nice, thanks justusschock

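@justusschock Good point about a zero loss not contributing to backprop. In newer versions of PyTorch, nn.CrossEntropyLoss has an ignore_index parameter which applies your logic internally and saves a bit of code. For instance, with reduction='sum' the loss is always zero if every entry of target equals the ignored index (0 here), no matter what the pred tensor is. With the default reduction='mean', recent versions return nan in that case instead, because the average is taken over zero samples.

import torch
import torch.nn as nn

# Targets equal to ignore_index contribute nothing to the summed loss
criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
pred = torch.randn(3, 5)
target = torch.zeros(3, dtype=torch.long)
print(criterion(pred, target))  # zero, since every target equals ignore_index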
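Applied to the original question, a sketch might look like the following. Assumptions: missing labels are stored as -100 (the default `ignore_index` of `nn.CrossEntropyLoss`, which keeps class 0 usable as a real age), and the logits come from a hypothetical two-headed network. With the default mean reduction, a batch in which every label of one task is ignored would average over zero samples, so the sketch skips that term explicitly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

MISSING = -100  # default ignore_index of nn.CrossEntropyLoss
criterion = nn.CrossEntropyLoss(ignore_index=MISSING)

# Hypothetical network outputs for a batch of 4 faces.
age_logits = torch.randn(4, 101, requires_grad=True)
gender_logits = torch.randn(4, 2, requires_grad=True)

age_target = torch.tensor([0, MISSING, 57, MISSING])  # age 0 is a valid class here
gender_target = torch.tensor([MISSING, MISSING, MISSING, MISSING])


def task_loss(logits, target):
    # If every label in the batch is missing, return a zero that stays
    # attached to the graph instead of averaging over zero samples.
    if (target == MISSING).all():
        return logits.sum() * 0.0
    return criterion(logits, target)


total_loss = task_loss(age_logits, age_target) + task_loss(gender_logits, gender_target)
total_loss.backward()
print(total_loss.item())
```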