A suitable way to punish mis-classification via nn.CrossEntropyLoss?

Good Afternoon,

I am wondering whether this is a suitable way to approach a problem or if I should consider alternatives (that I am unaware of) as well.

I have a relatively balanced classification problem involving 3 labels: 0, 1, and 2. I am interested, however, in paying a little more attention to getting the 0s right. My approach, when using nn.CrossEntropyLoss, is to penalize mistakes on the 0s more heavily than mistakes on the 1s or 2s by:

import torch
import torch.nn as nn

weights = [2.0, 1.0, 1.0]  # weight class 0 twice as heavily as classes 1 and 2
class_weights = torch.FloatTensor(weights).cuda()
criterion = nn.CrossEntropyLoss(weight=class_weights)

Is this a reasonable strategy, or are there other things that I might consider?

Thanks in advance for your time and consideration!

Hi Andrew!

Yes, this is a reasonable approach to more accurately classify
class-0 samples.

That is, adding these class weights to your loss function will train
your network to more often correctly label class-0 samples as
class-0 (fewer false negatives), but at the cost of more frequently
incorrectly labelling class-1 and class-2 samples as class-0 (more
false positives from the perspective of class-0).

Just to emphasize what I alluded to above, this depends on what
you mean by “getting the 0s right” and “punish mistakes for the 0s”.

If you consider mislabelling a class-1 as a class-0 (a false positive)
as not “getting the 0s right” and a “mistake for the 0s,” then this
likely won’t accomplish your goal. But if you’re only counting
mislabelling an actual class-0 (a false negative) as a mistake, this
is the way to go.

Best.

K. Frank

Thanks so much, this makes a lot of sense and this is precisely what I want:

adding these class weights to your loss function will train
your network to more often correctly label class-0 samples as
class-0 (fewer false negatives)

Out of curiosity, do you mind elaborating on that scenario and what one might do about it?

Just to emphasize what I alluded to above, this depends on what
you mean by “getting the 0s right” and “punish mistakes for the 0s”.

If you consider mislabelling a class-1 as a class-0 (a false positive)
as not “getting the 0s right” and a “mistake for the 0s,” then this
likely won’t accomplish your goal.

Hi Andrew!

Your network does what your loss function trains it to do.

As a general rule, if you train your network to do better on one thing,
then – all else being equal – it will be likely to do worse on something
else.

(That doesn’t mean you can’t train your network to do better on
everything, perhaps by training longer, or using a more apt loss
function, or using a better optimization algorithm, or training with
more or better data, etc.)

My point is that by using class weights that favor class-0, you're
telling your network training that you care less about getting classes
1 and 2 right, so, in general, your network won't perform as well on
classes 1 and 2, including mislabelling them as class-0 – because
that's what you trained it to do.

Now, if you don’t really care about mixing up classes 1 and 2, but
want to get class-0 right from both a false negative and false positive
perspective (that is, you’re willing, e.g., to mislabel class-1 as class-2,
but you don’t want to mislabel class-1 as class-0), you could, at the
extreme, train a binary classifier that identifies class-0 vs. everything
else. (Again, a trade-off: do better on both class-0 false positives and
false negatives, at the cost of not distinguishing class 1 from class-2.)
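
(For concreteness, here is a rough, untested sketch of that binary
variant. The names model, x, and target are just placeholders for your
network, inputs, and original integer labels, and the network would
need two output units for this version.)

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# remap the original labels: class 0 stays 0, classes 1 and 2 both become 1
binary_target = (target != 0).long()

logits = model(x)   # shape (N, 2) for this two-class version
loss = criterion(logits, binary_target)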

Now, some speculation, because I’ve never actually tried this. You
could add to your conventional three-class loss function (class-0 vs.
class-1 vs. class-2) that does distinguish between class-1 and class-2,
a two-class loss function (class-0 vs other-than-class-0). This will
help train your network to do better on both class-0 false negatives
and false positives, while still somewhat distinguishing class-1 and
class-2, but at the cost of not distinguishing them as well.

This will also bias your network to do better on class-0, but in a
way that’s different than overweighting class-0 in your conventional
three-class loss function. You’d still be making a trade-off, just a
different one. Which way to go depends on the details of what’s
more important to you.
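
If it helps, here is a rough, untested sketch of what I mean by adding
the two-class term to the three-class loss. The function name
combined_loss and the weight alpha are made up for illustration, and
logits are your network's usual three-class outputs:

import torch
import torch.nn.functional as F

def combined_loss(logits, target, alpha=0.5):
    # conventional three-class cross entropy
    ce3 = F.cross_entropy(logits, target)

    # collapse the three-class probabilities into class-0 vs. the rest
    log_probs = F.log_softmax(logits, dim=1)                  # (N, 3)
    log_p0 = log_probs[:, 0]                                  # log P(class 0)
    log_p_rest = torch.logsumexp(log_probs[:, 1:], dim=1)     # log P(class 1 or 2)
    two_class_logits = torch.stack([log_p0, log_p_rest], dim=1)

    # binary target: 0 if the sample is class 0, 1 otherwise
    binary_target = (target != 0).long()
    ce2 = F.cross_entropy(two_class_logits, binary_target)

    # alpha weights the extra binary term (the value here is just a placeholder)
    return ce3 + alpha * ce2

You would then call, e.g., loss = combined_loss(model(x), target) in
place of your current criterion.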

Best.

K. Frank


Wow! Thank you so much for the insightful and easy-to-understand post. Your comment below actually sparked a memory of something I read not too long ago, a journal article I believe, which I think did just that. I am going to look for it and see if I can apply the same idea for fun/learning.

Now, some speculation, because I’ve never actually tried this. You
could add to your conventional three-class loss function (class-0 vs.
class-1 vs. class-2) that does distinguish between class-1 and class-2,
a two-class loss function (class-0 vs other-than-class-0). This will
help train your network to do better on both class-0 false negatives
and false positives, while still somewhat distinguishing class-1 and
class-2, but at the cost of not distinguishing them as well.

If interested, I can post it here for you to skim – but otherwise, this has been quite the learning experience for me. Thank you again!

Edit: I found it. It was a Stanford ML final project paper. The point is at the end of page 3: https://github.com/JRC1995/BERT-Disaster-Classification-Capsule-Routing/blob/master/Project_Report.pdf. I think it's similar to your point.