Using `BCEWithLogitsLoss` for multi-label classification

Question

The key difference between nn.CrossEntropyLoss() and nn.BCEWithLogitsLoss() is that the former uses Softmax while the latter uses multiple Sigmoids when computing the loss. I am confused about the following:

  • Does using nn.BCEWithLogitsLoss() and setting a threshold (say 0.5) implicitly assume we are doing multi-label classification?
  • If so, suppose I wrongly use nn.BCEWithLogitsLoss() when I should have used nn.CrossEntropyLoss(). If I discard the 0.5 threshold and instead choose the class with the highest probability (see the sketch below), the performance is still very good. But how is this equivalent to using nn.CrossEntropyLoss()?
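
For concreteness, here is a minimal sketch (with made-up logits) of what I mean by discarding the threshold and just picking the class with the highest probability:

```python
import torch

# made-up raw network outputs (logits) for one sample with 10 classes
logits = torch.tensor([[-1.2, 0.3, 2.5, -0.7, 0.1, -2.0, 1.1, 0.0, -0.5, 0.8]])

# multi-label style: per-class probabilities via Sigmoid, thresholded at 0.5
probs = torch.sigmoid(logits)
hard_predictions = (probs > 0.5).long()   # several classes (or none) can be "on"

# discarding the threshold: just take the highest-scoring class
predicted_class = probs.argmax(dim=1)     # tensor([2]) for these logits
```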

Hi Mr. Robot!

While true, this is hardly the key difference between the two.

Let me first clear up a potential point of confusion:

“Multi-class” classification means that a given sample is in precisely
one class. (One of your classes can be “background” or “no class”
or “unclassified” if this fits your workflow.) (Binary classification means
that you have two classes, e.g., “yes” and “no” or “0” and “1”.)

“Multi-label” classification means that each sample can be in any
number of the specified classes, including zero. So multi-label
classification can be understood as a series of binary classifications:
Is sample 1 in class A – yes or no? Is sample 1 in class B – yes or no?
And so on.

You can’t use CrossEntropyLoss to do multi-label classification.
(It’s just the ticket, though, for multi-class classification.)

You can use BCEWithLogitsLoss for multi-label classification
(as well as for ordinary, single-label binary classification).
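
To make this concrete, here is a minimal sketch (shapes and
target values are made up) of the differently shaped targets
the two losses expect:

```python
import torch
import torch.nn as nn

batch_size, num_classes = 4, 3
logits = torch.randn(batch_size, num_classes)   # raw network outputs

# multi-class (single-label): target is one class index per sample
ce = nn.CrossEntropyLoss()
class_indices = torch.tensor([0, 2, 1, 2])
loss_ce = ce(logits, class_indices)

# multi-label: target is a float multi-hot vector per sample
# (each sample can have zero, one, or several classes switched on)
bce = nn.BCEWithLogitsLoss()
multi_hot = torch.tensor([[1., 0., 1.],
                          [0., 0., 0.],
                          [0., 1., 0.],
                          [1., 1., 1.]])
loss_bce = bce(logits, multi_hot)
```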

Could you clarify whether you are doing multi-class or multi-label
classification, and then follow up with any further questions you have?

Good Luck.

K. Frank


Thank you for your explanation.

I am using MNIST data for multi-class classification (there are ten classes, 0 through 9).

I could use both nn.CrossEntropyLoss() and nn.BCEWithLogitsLoss() on MNIST and get good performance. However, before reading your reply, I thought nn.BCEWithLogitsLoss() could only be used for multi-label classification (you said it can be used in this setting).
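
Roughly, this is how I set up the nn.BCEWithLogitsLoss() case (a simplified sketch, not my exact training code): the integer labels are one-hot encoded so that the targets match the ten logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.BCEWithLogitsLoss()

# logits: output of the network for a batch of MNIST images, shape (N, 10)
logits = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))            # integer class labels, 0-9

# BCEWithLogitsLoss needs float targets with the same shape as the logits,
# so the single-label targets are encoded as one-hot vectors
targets = F.one_hot(labels, num_classes=10).float()
loss = criterion(logits, targets)
```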

Then here come two questions:

  • Focusing on multi-class classification: if both losses can achieve good performance, I suspect they should be similar in some sense, but they actually use Softmax and multiple Sigmoids respectively, which I deem very different. For example, with Softmax it makes sense to choose the class that has the highest probability. With Sigmoid, however, it is likely that two or more classes have very close outputs (say 0.79, 0.8, 0.81), so choosing one particular class does not seem to make sense, and this is my main confusion about using nn.BCEWithLogitsLoss() for multi-class classification.
  • Focusing on multi-label classification: setting nn.CrossEntropyLoss() aside, we now focus on nn.BCEWithLogitsLoss(). Suppose I am building an image recognition system that tries to identify whether a car, people, and a stop sign are all present in a given image. This is clearly a multi-label classification problem (with three labels). However, people generally set 0.5 as the decision threshold for a 1/0 classification, so an output like [0.2, 0.3, 0.4] will be regarded as [0, 0, 0] by hard thresholding. I believe this is similar to logistic regression, where the decision threshold should be tuned, so hard thresholding does not seem to make sense. Now comes the main question: how do I tune three thresholds at the same time, or should I just output probabilities (like many object detection algorithms do)? (See the sketch below.)
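
To make the second question concrete, this is the kind of inference step I have in mind (a sketch with made-up numbers; the per-class thresholds are the quantities I would want to tune):

```python
import torch

# made-up per-class probabilities for one image: [car, person, stop sign]
probs = torch.tensor([0.2, 0.3, 0.4])

# hard thresholding at 0.5 for every class
hard = (probs > 0.5).long()             # tensor([0, 0, 0])

# per-class thresholds instead of a single 0.5 (values here are arbitrary)
thresholds = torch.tensor([0.30, 0.25, 0.35])
tuned = (probs > thresholds).long()     # tensor([0, 1, 1])
```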

Hello Mr. Robot!

I haven’t done the experiment, but I very much expect that if
you train a network on reasonably clean multi-class data (say,
MNIST digits) with both CrossEntropyLoss and
BCEWithLogitsLoss you will get better performance using
CrossEntropyLoss.

It’s worth saying a few words about why:

When you train with BCEWithLogitsLoss (with ten outputs
from your network) you really are training a multi-label,
multi-class classifier. You happen to be training it with samples
that always correspond to exactly one class. If your training
data is representative and reasonably clean, your network will
learn that for the kinds of samples you give it, the right answer is
exactly one label. Then when applied to a test sample (assuming
that your training and test data are similar in character) you
should get predictions that have a high probability for exactly
one of the classes, and low probabilities for the others.

But your network had to learn that you were giving it – at
least with the training set you were using – a single-label
(but multi-class) problem.

On the other hand, if you use CrossEntropyLoss you are,
in essence, telling your network that this is a single-label
problem, so your network doesn’t have to learn this. You’ve
made its job easier by telling it this, so you should expect
better performance.

Well, as a semantic quibble, you could consider ordinary binary
classification to be multi-label in the sense that you have one class
(say, “cat”) and zero (nothing, so “no cat”) or one (your only class
is “cat”, so “yes cat”) labels. You can’t have more than one label
because you only have one class. That is, you can have none, any,
or all of your labels (in this case only one – “cat”) turned off or on.
So this is a multi-label problem.

I think you’ve identified the main point.

First, a technical note: In both cases the output of your network
will be “raw scores” (logits). They, in effect, get converted to
probabilities internally to the loss functions. I will speak in terms
of probabilities.

CrossEntropyLoss internally uses (in effect) Softmax to convert
the output of your network into probabilities for which single class
is being predicted. These probabilities sum to one.

BCEWithLogitsLoss internally uses (in effect) Sigmoid to convert
the output of your network into probabilities for each of your classes
individually being active. These probabilities individually range
from zero to one (and sum anywhere from zero to your number
of classes, in the MNIST example, ten).
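
To make this concrete, here is a small sketch (with arbitrary,
made-up logits) of the two conversions:

```python
import torch

logits = torch.tensor([1.5, 0.2, -0.3, 2.0])   # arbitrary raw scores for four classes

p_softmax = torch.softmax(logits, dim=0)
print(p_softmax.sum())    # sums to 1 (up to floating point): which single class is it?

p_sigmoid = torch.sigmoid(logits)
print(p_sigmoid)          # each entry independently in (0, 1); the sum is unconstrained
```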

So, to use your example of probabilities (0.79, 0.8, 0.81),
your network is saying that class A, B, and C all have a probability
of about 80% of being active, so you should logically predict that
your image contains an “A” and a “B” and a “C”. (You shouldn’t
say that your image only contains a “C”, but it’s a kind of squiggly
“C” that looks rather like an “A” and a “B”.)

Of course, if your problem really is single-label (multi-class),
you can’t have “A”, “B”, and “C” at the same time. This is why
BCEWithLogitsLoss (with its internal Sigmoid) isn’t a good
fit for the single-label problem.

Now with CrossEntropyLoss (with its internal Softmax) you
will never get two classes both with greater than 50% probability.
You could get (0.33, 0.33, 0.34, 0.0, 0.0, ...), in
which case you would predict class “C”, but with low certainty,
because it’s a squiggly “C” that looks a lot like an “A” and a “B”.

And this is reasonable. Your network says there is only a 40%
probability of there being a stop sign in the image. Your best guess
is that there isn’t a stop sign (but you wouldn’t bet the farm on it).

So perhaps your network is saying that the bicycle in the image
looks a little bit like a car (20%), the bags of garbage look
somewhat like people (30%), and the restaurant logo looks
rather like a stop sign (40%). But your best prediction is that
none of your classes appear in the image.

Well, as I argued above, the hard, a priori thresholding does make
sense.

Nonetheless, it is plausible that you could develop some tuning
scheme for your post-network thresholds that would give you
better accuracy. But I think such tuning would likely be highly
dependent on your particular problem and network.

I don’t know of any good post-network threshold-tuning scheme,
but it’s believable that you could develop one.
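
If you wanted to experiment, one plausible (and untested by me)
scheme would be to pick each class’s threshold separately on
held-out validation data, for example by maximizing that class’s
F1 score. Here is a minimal sketch (the function name and the F1
criterion are just my assumptions, not an established recipe):

```python
import torch

def tune_thresholds(val_probs, val_labels, candidates=None):
    """Pick a per-class threshold that maximizes F1 on validation data.

    val_probs:  (N, C) sigmoid probabilities from the trained network
    val_labels: (N, C) 0/1 ground-truth multi-hot labels
    """
    if candidates is None:
        candidates = torch.linspace(0.05, 0.95, 19)
    num_classes = val_probs.shape[1]
    best = torch.full((num_classes,), 0.5)
    for c in range(num_classes):
        best_f1 = -1.0
        for t in candidates:
            pred = (val_probs[:, c] > t).float()
            tp = (pred * val_labels[:, c]).sum()
            fp = (pred * (1 - val_labels[:, c])).sum()
            fn = ((1 - pred) * val_labels[:, c]).sum()
            f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)
            if f1 > best_f1:
                best_f1, best[c] = f1, t
        # best[c] now holds this class's threshold chosen on the validation set
    return best
```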

Whether you output probabilities or thresholded “yes / no”
predictions really depends on your use case – on how you
use the output of your classifier. There are certainly many
use cases where outputting probabilities is appropriate.

Best.

K. Frank


It is really nice to see such a detailed and convincing reply! Thank you very much!