A suitable way to punish mis-classification via nn.CrossEntropyLoss?

Good Afternoon,

I am wondering whether this is a suitable way to approach a problem or if I should consider alternatives (that I am unaware of) as well.

I have a relatively balanced classification problem involving 3 labels: 0, 1, and 2. I am interested, however, in paying a little more attention to getting the 0s right. My approach, when using nn.CrossEntropyLoss, is to punish mistakes on the 0s more than mistakes on the 1s or 2s by:

import torch
import torch.nn as nn

weights = [2.0, 1.0, 1.0]  # class 0 counts twice as much as classes 1 and 2
class_weights = torch.tensor(weights).cuda()
criterion = nn.CrossEntropyLoss(weight=class_weights)

Is this a reasonable strategy, or are there other things that I might consider?

Thanks in advance for your time and consideration!

Hi Andrew!

Yes, this is a reasonable approach to more accurately classify
class-0 samples.

That is, adding these class weights to your loss function will train
your network to more often correctly label class-0 samples as
class-0 (fewer false negatives), but at the cost of more frequently
incorrectly labelling class-1 and class-2 samples as class-0 (more
false positives from the perspective of class-0).
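To make the effect of the weights concrete, here is a minimal sketch of what nn.CrossEntropyLoss computes when it is given class weights and left at its default mean reduction. The logits and labels are made up for illustration; only the [2.0, 1.0, 1.0] weights come from the question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up logits for four samples and their true labels (illustrative only).
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3],
                       [0.1, 0.2, 2.5],
                       [0.4, 1.0, 0.9]])
targets = torch.tensor([0, 1, 2, 0])

class_weights = torch.tensor([2.0, 1.0, 1.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)
weighted_loss = criterion(logits, targets)

# The same value by hand: each sample's loss is scaled by the weight of its
# true class, and the total is divided by the sum of those weights.
per_sample = F.cross_entropy(logits, targets, reduction="none")
sample_weights = class_weights[targets]
manual_loss = (sample_weights * per_sample).sum() / sample_weights.sum()
```

The weight is looked up by the true label, so a class-0 sample that is predicted poorly contributes twice as much to the loss as an equally poorly predicted class-1 or class-2 sample. Nothing in this term singles out a class-1 sample that is wrongly predicted as class-0.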

Just to emphasize what I alluded to above, this depends on what
you mean by “getting the 0s right” and “punish mistakes for the 0s”.

If you consider mislabelling a class-1 as a class-0 (a false positive)
as not “getting the 0s right” and a “mistake for the 0s,” then this
likely won’t accomplish your goal. But if you’re only counting
mislabelling an actual class-0 (a false negative) as a mistake, this
is the way to go.
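One way to check which kind of mistake you actually care about is to count the two separately on a validation set. The helper below is a hypothetical sketch (the name class0_errors and the sample tensors are not from the thread); it assumes preds holds the argmax of the network output.

```python
import torch

def class0_errors(preds, targets):
    """Return (false_negatives, false_positives) from the view of class-0."""
    is_zero = targets == 0      # samples that really are class-0
    pred_zero = preds == 0      # samples the network labelled class-0
    false_neg = (is_zero & ~pred_zero).sum().item()   # class-0 missed
    false_pos = (~is_zero & pred_zero).sum().item()   # class-1/2 called class-0
    return false_neg, false_pos

# Made-up labels and predictions for eight samples.
targets = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
preds = torch.tensor([0, 0, 1, 0, 1, 2, 0, 2])
fn, fp = class0_errors(preds, targets)
```

Tracking both numbers with and without the class weights should show whether the drop in false negatives is worth the rise in false positives for your use case.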

Best.

K. Frank

Thanks so much, this makes a lot of sense and this is precisely what I want:

adding these class weights to your loss function will train
your network to more often correctly label class-0 samples as
class-0 (fewer false negatives)

Out of curiosity, do you mind elaborating on this problem and what one might do?

Just to emphasize what I alluded to above, this depends on what
you mean by “getting the 0s right” and “punish mistakes for the 0s”.

If you consider mislabelling a class-1 as a class-0 (a false positive)
as not “getting the 0s right” and a “mistake for the 0s,” then this
likely won’t accomplish your goal.

Hi Andrew!

Your network does what your loss function trains it to do.

As a general rule, if you train your network to do better on one thing,
then – all else being equal – it will be likely to do worse on something
else.

(That doesn’t mean you can’t train your network to do better on
everything, perhaps by training longer, or using a more apt loss
function, or using a better optimization algorithm, or training with
more or better data, etc.)

My point is that by using class weights that favor class-0, you’re telling
your network training that you care less about getting classes 1 and
2 right, so, in general, your network won’t perform as well on classes
1 and 2, including mislabelling them as class-0 – because that’s what
you trained it to do.

Now, if you don’t really care about mixing up classes 1 and 2, but
want to get class-0 right from both a false negative and false positive
perspective (that is, you’re willing, e.g., to mislabel class-1 as class-2,
but you don’t want to mislabel class-1 as class-0), you could, at the
extreme, train a binary classifier that identifies class-0 vs. everything
else. (Again, a trade-off: do better on both class-0 false positives and
false negatives, at the cost of not distinguishing class-1 from class-2.)
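As a rough sketch of that extreme case, the three-class labels can be collapsed into class-0 vs. everything else and trained with a standard binary loss. The binary_logits tensor below is a made-up stand-in for the output of a network with a single output unit; it is an assumption for illustration, not something from the original setup.

```python
import torch
import torch.nn as nn

targets = torch.tensor([0, 1, 2, 0, 2])

# 1.0 where the sample is class-0, 0.0 for classes 1 and 2.
binary_targets = (targets == 0).float()

# Hypothetical single-logit outputs from a binary model, one per sample.
binary_logits = torch.tensor([1.2, -0.7, -2.0, 0.3, 0.5])

criterion = nn.BCEWithLogitsLoss()
loss = criterion(binary_logits, binary_targets)
```

Here both kinds of class-0 error are penalized directly: a low logit on a true class-0 sample (false negative) and a high logit on a class-1 or class-2 sample (false positive).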

Now, some speculation, because I’ve never actually tried this. You
could add to your conventional three-class loss function (class-0 vs.
class-1 vs. class-2) that does distinguish between class-1 and class-2,
a two-class loss function (class-0 vs other-than-class-0). This will
help train your network to do better on both class-0 false negatives
and false positives, while still somewhat distinguishing class-1 and
class-2, but at the cost of not distinguishing them as well.

This will also bias your network to do better on class-0, but in a
way that’s different than overweighting class-0 in your conventional
three-class loss function. You’d still be making a trade-off, just a
different one. Which way to go depends on the details of what’s
more important to you.
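A minimal sketch of that combined loss, in the same speculative spirit: it reuses the three-class logits to build the class-0 vs. other-than-class-0 term, using the log-odds implied by the softmax. The function name combined_loss and the binary_weight knob are assumptions for illustration, and this is only one of several ways the two terms could be combined.

```python
import torch
import torch.nn.functional as F

def combined_loss(logits, targets, binary_weight=1.0):
    """Three-class cross entropy plus a class-0-vs-rest binary term.

    binary_weight is a hypothetical knob for the relative strength of the
    binary term; it is not from the discussion above.
    """
    three_class = F.cross_entropy(logits, targets)

    # Log-odds of class-0 vs. not-class-0 implied by the softmax:
    # log p0 - log(p1 + p2) = logit0 - logsumexp(logit1, logit2)
    binary_logit = logits[:, 0] - torch.logsumexp(logits[:, 1:], dim=1)
    binary_target = (targets == 0).float()
    two_class = F.binary_cross_entropy_with_logits(binary_logit, binary_target)

    return three_class + binary_weight * two_class

# Made-up logits and labels for a quick check.
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [1.8, 1.5, 0.3],
                       [0.1, 0.2, 2.5]])
targets = torch.tensor([0, 1, 2])
loss = combined_loss(logits, targets)
```

For a true class-0 sample the binary term is -log p0, the same as the three-class term, so it acts like extra weight on class-0. For a class-1 or class-2 sample it adds -log(1 - p0), which explicitly penalizes probability mass placed on class-0. That second part is what plain class weighting does not provide.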

Best.

K. Frank

Wow! Thank you so much for the insightful and easy-to-understand post. Your comment below actually sparked a memory of something that I read not too long ago, a journal article I believe, which I think did just that. I am going to look for it and see if I can apply the same for fun/learning.

Now, some speculation, because I’ve never actually tried this. You
could add to your conventional three-class loss function (class-0 vs.
class-1 vs. class-2) that does distinguish between class-1 and class-2,
a two-class loss function (class-0 vs other-than-class-0). This will
help train your network to do better on both class-0 false negatives
and false positives, while still somewhat distinguishing class-1 and
class-2, but at the cost of not distinguishing them as well.

If interested, I can post it here for you to skim – but otherwise, this has been quite the learning experience for me. Thank you again!

Edit: I found it. It was a Stanford ML final project paper. The relevant point is at the end of page 3: https://github.com/JRC1995/BERT-Disaster-Classification-Capsule-Routing/blob/master/Project_Report.pdf. I think it's similar to your point.