Create a Variation of Cross Entropy Loss with Per Pair of Classes Weights

I am after creating a Cross Entropy Loss with the addition of weighting per pair of classes.

For instance, let’s say I have 5 classes. I would like a greater penalty when the input (true) class is 1 and the predicted class is 5 than when the predicted class is 2.

In classical classification this is done with a loss matrix $D$, where $D_{ij}$ is the penalty factor for classifying class $i$ as class $j$.

How can one achieve this in PyTorch using a custom loss?
I am asking with regard to both speed and numerical stability.

Hi David!

If I understand your use case, you would not typically do this when
your targets (ground-truth labels) are categorical class labels. This
is because with certain (“hard,” probability-of-one) class labels, cross
entropy depends only on the predicted probability of the correct class
(so there are no “pairs of classes”).

If you insist on using per-pair weights with categorical labels, then only
the diagonal of your per-pair weight matrix, D, will matter, and you can
use PyTorch’s built-in CrossEntropyLoss with its weight argument:

```python
loss_fn = torch.nn.CrossEntropyLoss (weight = torch.diag (D))
```
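A quick check of what the weight argument does, assuming a randomly drawn example D and hard integer labels: each sample's loss is scaled by D[c, c] for its target class c, and with the default reduction='mean' the weighted sum is divided by the sum of the weights actually applied, not by the batch size.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
nBatch, nClass = 4, 5
D = torch.rand(nClass, nClass) + 0.1      # example per-pair weights (random, for illustration)
logits = torch.randn(nBatch, nClass)
targ = torch.randint(nClass, (nBatch,))

w = torch.diag(D)                         # only the diagonal is used with hard labels
loss_fn = torch.nn.CrossEntropyLoss(weight=w)
loss = loss_fn(logits, targ)

# manual version: per-sample -log p[target], scaled by w[target],
# then normalized by the sum of the weights that were applied
nll = -F.log_softmax(logits, dim=1).gather(1, targ.unsqueeze(1)).squeeze(1)
manual = (w[targ] * nll).sum() / w[targ].sum()
print(loss.item(), manual.item())
```

If you want a plain average over the batch instead, use reduction='sum' and divide by nBatch yourself.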

If you have probabilistic (“soft”) labels, then all elements of D will matter,
and you can implement per-pair-weighted, probabilistic-label cross entropy
as follows:

```python
import torch

# assumes that input and target both have shape [nBatch, nClass]
def perPairWeightedCE (input, target, D):
    logprobs = torch.nn.functional.log_softmax (input, dim = 1)
    return -(target * (logprobs @ D.T)).sum() / input.shape[0]
```

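Two sanity checks on perPairWeightedCE(), using random example data: with D equal to the identity it reduces to ordinary soft-label cross entropy, and with a diagonal D and one-hot targets it matches the built-in weighted loss up to normalization. This assumes PyTorch 1.10 or later, where F.cross_entropy accepts class-probability targets.

```python
import torch
import torch.nn.functional as F

# perPairWeightedCE as given above (input and target both [nBatch, nClass])
def perPairWeightedCE(input, target, D):
    logprobs = F.log_softmax(input, dim=1)
    return -(target * (logprobs @ D.T)).sum() / input.shape[0]

torch.manual_seed(0)
nBatch, nClass = 4, 5
logits = torch.randn(nBatch, nClass)

# soft (probabilistic) targets: each row sums to one
soft = torch.rand(nBatch, nClass).softmax(dim=1)

# 1) D = identity: ordinary soft-label cross entropy
lossI = perPairWeightedCE(logits, soft, torch.eye(nClass))
lossRef = F.cross_entropy(logits, soft)

# 2) diagonal D with one-hot targets: same as the weighted built-in
#    loss summed over the batch and divided by nBatch
w = torch.rand(nClass) + 0.1
targ = torch.randint(nClass, (nBatch,))
hard = F.one_hot(targ, nClass).float()
lossDiag = perPairWeightedCE(logits, hard, torch.diag(w))
lossBuiltin = F.cross_entropy(logits, targ, weight=w, reduction='sum') / nBatch
print(lossI.item(), lossRef.item(), lossDiag.item(), lossBuiltin.item())
```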

K. Frank


Hi @KFrank,
Thank you for the assistance.

What I want is to be able to have different penalties for different mistakes.
Let’s say we work on a sample of class 1.
If the net labels the sample as 5, I want a bigger penalty than if it labels it as 2.

Is that clearer?

Thank you.

Hi David!

Cross entropy (by definition) doesn’t work this way. But you can write
some other loss function that does have “per-pair” penalties.

I give a couple of possibilities in the example script below. misclassA()
simply weights the predicted probability of each incorrect class by the
per-pair weight given in your matrix D.
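Written out, with $N$ = nBatch, $C$ = nClass, $y_b$ the target class of sample $b$ and $p_{bk}$ its predicted probability for class $k$, misclassA() computes the following (the .mean() averages over both the batch and the class dimensions, hence the $1/(NC)$ factor):

$$
\mathcal{L}_A = \frac{1}{N C} \sum_{b=1}^{N} \sum_{k=1}^{C} D_{y_b k} \, p_{bk}
$$

Since the script sets $D_{yy} = 0$, the inner sum is the expected per-pair penalty for sample $b$ under the predicted distribution, i.e. the classic loss-matrix risk.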

My guess is that such a loss function won’t train very well. In my experience,
the logarithmic divergence for bad predictions in cross entropy seems to
be very helpful for training.

misclassB() (which I have not tried out on any kind of training) puts in
such a logarithmic divergence. Maybe it will work better.

Lastly, it might make sense to use cross entropy as your “base” loss
criterion to get its generally good training properties and then add on
something like misclassA() or misclassB() to also incorporate your
per-pair penalties.
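A minimal sketch of that combination: cross entropy plus a misclassA()-style penalty. The function name ceWithPairPenalty() and the mixing weight alpha are hypothetical, and alpha would need tuning for a real problem; the backward() call only confirms that gradients flow through both terms.

```python
import torch
import torch.nn.functional as F

def misclassA(pred, targ, D):
    # weighted probability of incorrect predictions, averaged over batch and classes
    return (D[targ] * pred).mean()

def ceWithPairPenalty(logits, targ, D, alpha=1.0):
    # hypothetical combined loss: standard cross entropy + alpha * per-pair penalty
    ce = F.cross_entropy(logits, targ)
    penalty = misclassA(logits.softmax(dim=1), targ, D)
    return ce + alpha * penalty

torch.manual_seed(0)
nBatch, nClass = 3, 5
D = torch.rand(nClass, nClass)
D -= D * torch.eye(nClass)                # no penalty for correct predictions
logits = torch.randn(nBatch, nClass, requires_grad=True)
targ = torch.randint(nClass, (nBatch,))

loss = ceWithPairPenalty(logits, targ, D, alpha=2.0)
loss.backward()
print(loss.item(), logits.grad.shape)
```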

Here is the script:

```python
import torch
print (torch.__version__)

torch.manual_seed (2022)

# loss is weighted probability of incorrect prediction
def misclassA (pred, targ, D):   # pred[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    return  (D[targ] * pred).mean()

# give incorrect predictions a logarithmic divergence
# (note, it would be better to work with log-probabilities than to use clamp)
def misclassB (pred, targ, D):   # pred[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    return  (D[targ] * (-(pred - 1.0).clamp (min = 1.e-7).log())).mean()

nBatch = 3
nClass = 5

D = torch.rand (nClass, nClass)   # per-pair class weights
D -= D * torch.eye (nClass)       # no penalty for correct predictions

logits = torch.randn (nBatch, nClass)
pred = logits.softmax (dim = 1)
targ = torch.randint (nClass, (nBatch,))

lossA = misclassA (pred, targ, D)
lossB = misclassB (pred, targ, D)

print ('lossA =', lossA)
print ('lossB =', lossB)

lossCE = torch.nn.functional.cross_entropy (logits, targ)   # standard cross entropy for comparison

print ('lossCE =', lossCE)
```

And here is its output:

```
lossA = tensor(0.0818)
lossB = tensor(7.3061)
lossCE = tensor(2.8050)
```
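One caveat about misclassB() as written: (pred - 1.0) is never positive, because softmax probabilities are at most one, so clamp (min = 1.e-7) replaces every entry with 1e-7. The function therefore returns -log(1e-7) ≈ 16.12 times the mean of D[targ], independent of the logits and with zero gradient, so lossB above is just a constant multiple of the average weight. The presumably intended quantity is -log(1 - p_k). Below is a sketch of that, computed from the logits with log-sum-exp (in the spirit of the script's comment about log-probabilities) instead of clamp. The names log1mSoftmax() and misclassBFixed() are hypothetical, and misclassBFixed() takes logits rather than pred.

```python
import torch

def log1mSoftmax(logits):
    # log(1 - softmax(logits)[:, k]) for every class k, computed as
    # logsumexp over the *other* logits minus logsumexp over all logits
    nClass = logits.shape[1]
    eye = torch.eye(nClass, dtype=torch.bool, device=logits.device)
    others = logits.unsqueeze(1).expand(-1, nClass, -1).masked_fill(eye, float('-inf'))
    return torch.logsumexp(others, dim=2) - torch.logsumexp(logits, dim=1, keepdim=True)

def misclassBFixed(logits, targ, D):   # logits[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    # -log(1 - p_k) diverges as the probability of a penalized class k goes to one
    return (D[targ] * -log1mSoftmax(logits)).mean()

torch.manual_seed(2022)
nBatch, nClass = 3, 5
D = torch.rand(nClass, nClass)
D -= D * torch.eye(nClass)             # no penalty for correct predictions
logits = torch.randn(nBatch, nClass, requires_grad=True)
targ = torch.randint(nClass, (nBatch,))

lossB = misclassBFixed(logits, targ, D)
lossB.backward()                       # gradient is now nonzero
print('lossB =', lossB.item())
```

The log-sum-exp form stays finite even when a probability rounds to exactly one in float32, where torch.log1p(-pred) would return -inf.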


K. Frank


@KFrank, what do you think about something like:

The idea is to add weights to the Cross Entropy loss based on the matrix.
Will it be friendly to the AD engine?
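The snippet referred to above is not shown here, so as a hypothetical stand-in, here is one way to "weight cross entropy by the matrix": scale each sample's cross entropy by one plus its expected per-pair penalty. The function name weightedCE() and that particular form are assumptions. Whatever the exact form, a loss built from indexing with integer targets, softmax, log_softmax and matrix products is differentiable, and torch.autograd.gradcheck can confirm that autograd's gradients agree with finite differences (it needs float64 inputs).

```python
import torch
import torch.nn.functional as F

def weightedCE(logits, targ, D):
    # hypothetical variant: per-sample cross entropy scaled by
    # (1 + expected per-pair penalty under the predicted distribution)
    ce = F.cross_entropy(logits, targ, reduction='none')      # [nBatch]
    penalty = (D[targ] * logits.softmax(dim=1)).sum(dim=1)    # [nBatch]
    return ((1.0 + penalty) * ce).mean()

torch.manual_seed(0)
nBatch, nClass = 3, 5
D = torch.rand(nClass, nClass, dtype=torch.float64)
D -= D * torch.eye(nClass, dtype=torch.float64)   # no penalty for correct predictions
logits = torch.randn(nBatch, nClass, dtype=torch.float64, requires_grad=True)
targ = torch.randint(nClass, (nBatch,))

# gradcheck compares autograd's gradient with finite differences
ok = torch.autograd.gradcheck(lambda x: weightedCE(x, targ, D), (logits,))
print(ok)
```

The usual pitfalls are not differentiability as such but operations that silently kill the gradient, such as a clamp that is active everywhere (as in misclassB() above) or converting tensors with .item() or .numpy() inside the loss.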