Training with missing labels

Hi!

I’m training a net on a multi-task problem. Let’s say the input is a facial image and the two losses are cross entropy on gender and cross entropy on age (age has 101 classes), so that `total_loss = loss_age + loss_gender`.

The problem is that some images have a missing label, e.g. an image where only the age is given, or only the gender.
I know that I should set the loss to 0 (or any constant) for a missing label, but I’m not sure whether I should write a new loss/criterion manually or whether there’s an elegant way to do that in PyTorch. How would you go about it?


First, you should treat age as a regression problem unless there is a good reason you need to do classification.

Second, you can simply index into the outputs of your model, select the rows where you have a label, and calculate the loss only on those rows.
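A minimal sketch of this row-selection idea for the age/gender setup. It assumes (hypothetically) that a missing label is encoded as -1 and that the model returns separate logits for each head; the function name `masked_multitask_loss` and the tensor shapes are illustrative, not from the thread.

```python
import torch
import torch.nn.functional as F

def masked_multitask_loss(age_logits, gender_logits, age_target, gender_target):
    """Cross entropy per task, computed only on rows whose label is present.

    Assumes a missing label is encoded as -1 (hypothetical convention).
    """
    total = age_logits.new_zeros(())  # scalar 0 on the same device/dtype
    for logits, target in ((age_logits, age_target), (gender_logits, gender_target)):
        has_label = target >= 0  # boolean mask over the batch
        if has_label.any():  # skip a task whose labels are all missing (avoids nan)
            total = total + F.cross_entropy(logits[has_label], target[has_label])
    return total

# Batch of 4: sample 1 has no age label, sample 3 has no gender label
age_logits = torch.randn(4, 101, requires_grad=True)
gender_logits = torch.randn(4, 2, requires_grad=True)
age_target = torch.tensor([30, -1, 45, 22])
gender_target = torch.tensor([0, 1, 1, -1])

loss = masked_multitask_loss(age_logits, gender_logits, age_target, gender_target)
loss.backward()
print(age_logits.grad[1].abs().sum())  # rows without an age label receive no gradient
```

Each task is averaged over its own labelled rows, so a batch with few age labels does not shrink the age loss. If every label in a batch is missing, the returned zero has no `grad_fn`, so you would need to skip `backward()` for that batch.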

Classification works better for age estimation (take a look at any age estimation paper or the recent papers from CVPR).

Also, my original suggestion of setting the loss to 0 or a constant is way, way cleaner than your suggestion. My question was whether there is a more elegant way.

Regarding the underlying mathematics, there is no difference between setting the loss to zero and not calculating/backpropagating it at all: a loss term fixed at zero contributes no gradient, so no parameter update results from it (the same result as not calculating a loss for this part at all).
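To see this concretely, here is a sketch that computes per-sample losses with `reduction="none"`, multiplies them by a 0/1 mask, and checks that the masked sample gets an all-zero gradient. The -1 encoding for a missing label and the gender-only batch are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 2, requires_grad=True)  # gender head, batch of 4
target = torch.tensor([0, 1, -1, 1])            # -1 marks a missing label (assumed)

mask = (target >= 0).float()
# clamp so the placeholder -1 is a valid class index; its loss is zeroed anyway
per_sample = F.cross_entropy(logits, target.clamp(min=0), reduction="none")
# mean over labelled rows only; clamp(min=1) avoids 0/0 if nothing is labelled
loss = (per_sample * mask).sum() / mask.sum().clamp(min=1)

loss.backward()
print(logits.grad[2])  # all zeros: the masked sample contributes no gradient
```

Multiplying by zero removes the sample's term from the loss, so autograd propagates exactly zero to that row, which matches not computing the term at all.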

On the implementation side, you could use negative values to mark missing labels and then simply do something like

```python
# assumes pred and target have the same shape, e.g. age as regression
pred[target < 0] = target[target < 0]
```

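and later on just calculate the loss as you did before. For losses such as MSE or L1, this results in a loss value of 0 for the affected samples, since prediction and target are identical there. (This does not carry over directly to cross entropy, where the prediction is a vector of logits rather than a single value.)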
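A runnable sketch of this trick, assuming age is predicted as a single float per image (regression) and a negative target marks a missing age. It uses `torch.where` instead of in-place indexed assignment, which builds a new tensor and leaves `pred` untouched; that substitution is a choice made for this sketch, not part of the original suggestion.

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([25.0, 40.0, 31.0], requires_grad=True)  # predicted ages
target = torch.tensor([27.0, -1.0, 30.0])                     # -1: age missing

missing = target < 0
# Where the label is missing, replace the prediction by the target itself;
# gradients flow to pred only where the label is present
pred_filled = torch.where(missing, target, pred)
loss = F.mse_loss(pred_filled, target, reduction="sum")
loss.backward()
print(loss.item(), pred.grad)
```

With the default `reduction="mean"`, the filled entries still count in the denominator, so the average is taken over all samples rather than only the labelled ones.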


Nice, thanks justusschock

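@justusschock Good point about zero loss not contributing to backprop. In newer versions of PyTorch, `nn.CrossEntropyLoss` has an `ignore_index` parameter that applies this logic internally and saves a bit of code: targets equal to `ignore_index` are skipped. For instance, with `reduction='sum'`, no matter what the `pred` tensor is, the loss is zero if `target` contains only zeros. (With the default `reduction='mean'`, a batch in which every target is ignored divides zero by zero and produces `nan`.)

```python
import torch
import torch.nn as nn

# Targets equal to ignore_index are skipped; reduction='sum' keeps an
# all-ignored batch at 0 instead of 0/0 = nan
criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
pred = torch.randn(3, 5)
target = torch.zeros(3, dtype=torch.long)  # every target is ignored
print(criterion(pred, target))
```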
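Applied to the age/gender question, one option is to mark missing labels with -100, the default `ignore_index` of `nn.CrossEntropyLoss`, so that class 0 stays usable. The sketch below assumes that encoding and avoids the all-ignored `nan` by summing and dividing by the number of labelled samples per task; `task_loss` and the tensor names are hypothetical.

```python
import torch
import torch.nn as nn

MISSING = -100  # default ignore_index of nn.CrossEntropyLoss
criterion = nn.CrossEntropyLoss(ignore_index=MISSING, reduction="sum")

def task_loss(logits, target):
    # Sum over labelled rows, then average; clamp(min=1) avoids 0/0 when all are missing
    n_labelled = (target != MISSING).sum().clamp(min=1)
    return criterion(logits, target) / n_labelled

age_logits = torch.randn(4, 101, requires_grad=True)
gender_logits = torch.randn(4, 2, requires_grad=True)
age_target = torch.tensor([30, MISSING, 45, 22])
gender_target = torch.tensor([0, 1, 1, MISSING])

total_loss = task_loss(age_logits, age_target) + task_loss(gender_logits, gender_target)
total_loss.backward()
```

For batches with at least one label per task this matches the default mean reduction; the explicit division only changes the all-missing case, where it returns 0 instead of `nan`.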