Need Guidance on Loss Function for Multi-Label/Multi-Output Classification

Hello everyone,

I’m encountering a challenge with the choice of the loss function for a multi-label/multi-output classification problem in PyTorch. I hope you can provide some insights and guidance.

Problem Description: I’m working on a problem where we have 47 labels, and each label can take one of three possible values (0, 1, -1). So essentially, it’s a multi-label classification problem with three classes per label. I’ve never worked on such a problem before, and I’m a bit confused about which loss function to use.

Example Data: Here’s a simplified example of my data:

y: tensor([[[ 1., -1., -1., 1., -1., 1., ... ],
           ...
           [ 0., 1., -1., 0., 1., -1., ... ]],

y_pred: tensor([[ 0.0156, 0.0109, 0.0048, 0.0125, 0.0005, ... ],
                ...
                [ 0.0172, 0.0048, 0.0098, 0.0057, -0.0036, ... ]],

Data Shapes:

  • y shape: torch.Size([8, 47])
  • y_pred shape: torch.Size([8, 47])

Dilemma: I’m aware that the Binary Cross Entropy loss function technically can’t be used here because there are three classes. Cross-entropy seems like the right choice for classification problems like this, but I’m struggling with the fact that I need to compare probabilities (ranging from 0 to 1) with labels (-1, 0, 1).

I’m also used to working with one-hot encoding, but I’m unsure if it’s applicable in this case.

I’d greatly appreciate any advice or suggestions on how to approach this problem. What loss function should I use, and how should I handle the class labels and predictions?

Thank you in advance for your help!

Hi Mario!

The short story: Treat this as what pytorch’s CrossEntropyLoss calls the
“K-dimensional case” (with K = 1), where the trailing “channels” dimension
is of length 47.

Yes, your framing is an appropriate description of the problem. You have 47
three-class classification problems that all take the same input, so you get
better throughput, and likely better training, if you build one model that
performs all 47 classifications at the same time.

We will group these 47 problems together into a trailing “labels” dimension
that I will call the “channels” dimension.

First, shift the values in your y tensor (the ground-truth target tensor) so that
they run over {0, 1, 2} and will be proper categorical class labels from the
perspective of pytorch’s CrossEntropyLoss.
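
For example, a minimal sketch, assuming y is the float tensor of {-1, 0, 1}
values shown above; the shift also needs a cast, since CrossEntropyLoss
expects integer (long) class labels:

    # shift {-1., 0., 1.} -> {0, 1, 2} and cast to integer class labels
    y = (y + 1).long()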

You will want your y_pred tensor (the output of your model that will be the
input to CrossEntropyLoss) to be (unnormalized) log-probabilities for your
three classes.

(Note that pytorch’s CrossEntropyLoss documentation has started referring to
its input as “unnormalized logits.” This is technically incorrect, as the
values are actually unnormalized log-probabilities, but the point being made,
that the input lives in “log space” rather than “probability space,” is
basically correct.)

I assume that the 8 is your batch dimension. So for each batch element and
each of your 47 “channels” you have a three-class categorical class label that
takes values in {0, 1, 2}. (Note that the example y you gave above appears to
be a three-dimensional tensor. You do want a two-dimensional tensor of shape
[nBatch = 8, channels = 47].)
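
If that extra dimension really is a spurious singleton, squeezing it out is
enough. This is a hedged guess, assuming the extra bracket is a size-one
dimension such as [1, 8, 47] or [8, 1, 47]:

    y = y.squeeze()    # drops all size-one dimensions, e.g. [1, 8, 47] -> [8, 47]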

Your y_pred should have three (unnormalized) class log-probabilities (which
is to say, one for each of your three classes) for each batch element and channel.
It should therefore be of shape [nBatch = 8, nClass = 3, channels = 47]
and have values that run from -inf to inf.
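
A quick shape check, with random tensors standing in for your real data, shows
how CrossEntropyLoss consumes this “K-dimensional” layout:

    import torch

    loss_fn = torch.nn.CrossEntropyLoss()

    y_pred = torch.randn(8, 3, 47)       # [nBatch, nClass, channels]
    y = torch.randint(0, 3, (8, 47))     # [nBatch, channels], labels in {0, 1, 2}

    loss = loss_fn(y_pred, y)            # scalar, averaged over batch and channels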

Conceptually, you produce these log-probabilities by predicting (for each
batch element and channel) a set of three probabilities (one for each class)
that sum to one.

In practice you do this by predicting three unnormalized log-probabilities. They
are unnormalized in the sense that when converted to probabilities, they don’t
necessarily sum to one. The conversion to probabilities and normalization (that
is, that they sum to one) occurs, in effect, inside of CrossEntropyLoss.
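
You can see this explicitly: CrossEntropyLoss is, in effect, log_softmax()
followed by NLLLoss, and softmax() is what makes the three values sum to one.
A small illustration with random stand-in tensors:

    import torch

    y_pred = torch.randn(8, 3, 47)            # unnormalized log-probabilities
    probs = torch.softmax(y_pred, dim = 1)    # normalize along the class dimension
    print(probs.sum(dim = 1))                 # all (approximately) 1.0

    y = torch.randint(0, 3, (8, 47))
    ce = torch.nn.functional.cross_entropy(y_pred, y)
    nll = torch.nn.functional.nll_loss(torch.log_softmax(y_pred, dim = 1), y)
    print(torch.allclose(ce, nll))            # True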

Probably the easiest way to do this is to have the final layer of your model
be

    lin = torch.nn.Linear(in_features = in_features, out_features = 3 * 47)

(for your use case of nClass = 3 and channels = 47) and then group the classes
together by reshaping the output:

    output = lin(t).reshape(-1, 3, 47)
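
Putting this together, a minimal sketch of such a head; the class name
MultiHead and in_features = 128 are illustrative assumptions, not anything
taken from your actual model:

    import torch

    class MultiHead(torch.nn.Module):
        """Hypothetical head: one Linear layer serving 47 three-class problems."""
        def __init__(self, in_features, n_class = 3, channels = 47):
            super().__init__()
            self.n_class = n_class
            self.channels = channels
            self.lin = torch.nn.Linear(in_features, n_class * channels)

        def forward(self, t):    # t: [nBatch, in_features]
            return self.lin(t).reshape(-1, self.n_class, self.channels)

    head = MultiHead(in_features = 128)
    y_pred = head(torch.randn(8, 128))    # shape: [8, 3, 47]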

Clear your mind of ALL thoughts of “one-hot encoding.” CrossEntropyLoss takes
integer class labels directly, so there is no need to one-hot encode anything.

Best.

K. Frank