Hi David!
Cross entropy (by definition) doesn’t work this way. But you can write
some other loss function that does have “per-pair” penalties.
I give a couple of possibilities in the example script below. misclassA()
just weights the probability for each incorrect prediction by the per-pair
weight given in your matrix D.
My guess is that such a loss function won’t train very well. In my experience,
the logarithmic divergence for bad predictions in cross entropy seems to
be very helpful for training.
misclassB() (which I have not tried out on any kind of training) puts in
such a logarithmic divergence. Maybe it will work better.
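As a rough illustration of why the logarithmic divergence matters, the sketch below compares gradient sizes for a confidently wrong prediction. The weight matrix and logits are made-up values chosen only for this illustration, and the helper simply repeats the misclassA() form so the snippet runs on its own.

```python
import torch

# weighted probability of incorrect prediction (same form as misclassA)
def misclassA (pred, targ, D):   # pred[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    return (D[targ] * pred).mean()

# illustrative values (assumptions): three classes, the target is class 1,
# but the logits put almost all of the probability on class 0
D = torch.tensor ([[0.0, 1.0, 1.0],
                   [1.0, 0.0, 1.0],
                   [1.0, 1.0, 0.0]])
targ = torch.tensor ([1])

logitsA = torch.tensor ([[10.0, 0.0, 0.0]], requires_grad = True)
misclassA (logitsA.softmax (dim = 1), targ, D).backward()

logitsCE = torch.tensor ([[10.0, 0.0, 0.0]], requires_grad = True)
torch.nn.functional.cross_entropy (logitsCE, targ).backward()

gradA = logitsA.grad.norm()
gradCE = logitsCE.grad.norm()
print ('misclassA grad norm     =', gradA)    # very small: the softmax has saturated
print ('cross entropy grad norm =', gradCE)   # order one: the log keeps the gradient alive
```

With these values the probability-weighted loss gives almost no gradient signal even though the prediction is badly wrong, while cross entropy still pushes hard on the logits.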
Lastly, it might make sense to use cross entropy as your “base” loss
criterion to get its generally good training properties and then add on
something like misclassA() or misclassB() to also incorporate your
per-pair penalties.
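A minimal sketch of that combination might look like the following. The name combinedLoss and the penaltyWeight parameter are hypothetical (they are not part of PyTorch), and the right value for the weight would have to be found by experiment.

```python
import torch

# weighted probability of incorrect prediction
def misclassA (pred, targ, D):   # pred[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    return (D[targ] * pred).mean()

# cross entropy as the base loss plus a weighted per-pair penalty
# (penaltyWeight is a hypothetical hyperparameter, to be tuned)
def combinedLoss (logits, targ, D, penaltyWeight = 1.0):
    lossCE = torch.nn.functional.cross_entropy (logits, targ)
    lossPair = misclassA (logits.softmax (dim = 1), targ, D)
    return lossCE + penaltyWeight * lossPair

torch.manual_seed (2022)
nBatch = 3
nClass = 5
D = torch.rand (nClass, nClass)   # per-pair class weights
D -= D * torch.eye (nClass)       # no penalty for correct predictions
logits = torch.randn (nBatch, nClass, requires_grad = True)
targ = torch.randint (nClass, (nBatch,))

loss = combinedLoss (logits, targ, D)
loss.backward()                   # gradients flow through both terms
print ('combined loss =', loss)
```

Setting penaltyWeight to zero recovers plain cross entropy, which makes it easy to check how much the per-pair term changes training.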
Here is the script:
import torch
print (torch.__version__)
torch.manual_seed (2022)
# loss is weighted probability of incorrect prediction
def misclassA (pred, targ, D):   # pred[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    return (D[targ] * pred).mean()
# give incorrect predictions a logarithmic divergence: -log (1 - p) grows without bound as p -> 1
# (note, it would be better to work with log-probabilities than to use clamp)
def misclassB (pred, targ, D):   # pred[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    return (D[targ] * (-(1.0 - pred).clamp (min = 1.e-7).log())).mean()
nBatch = 3
nClass = 5
D = torch.rand (nClass, nClass) # per-pair class weights
D -= D * torch.eye (nClass) # no penalty for correct predictions
logits = torch.randn (nBatch, nClass)
pred = logits.softmax (dim = 1)
targ = torch.randint (nClass, (nBatch,))
lossA = misclassA (pred, targ, D)
lossB = misclassB (pred, targ, D)
print ('lossA =', lossA)
print ('lossB =', lossB)
lossCE = torch.nn.functional.cross_entropy (logits, targ) # standard cross entropy for comparison
print ('lossCE =', lossCE)
And here is its output:
1.12.0
lossA = tensor(0.0818)
lossB = tensor(7.3061)
lossCE = tensor(2.8050)
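The note in the script about working with log-probabilities rather than clamp could be sketched as follows, assuming the quantity wanted for each class j is -log (1 - p_j). It uses the identity log (1 - p_j) = logsumexp of the other logits minus logsumexp of all logits. The function name misclassBLogits is hypothetical, and this has not been tried in training.

```python
import torch

# per-pair weighted -log (1 - p), computed from logits without clamp
# (misclassBLogits is a hypothetical name, not a PyTorch function)
def misclassBLogits (logits, targ, D):   # logits[nBatch, nClass], targ[nBatch], D[nClass, nClass]
    nClass = logits.shape[1]
    # z[b, j, k] = logits[b, k]; for each j, mask out the k == j entry
    z = logits.unsqueeze (1).expand (-1, nClass, -1)
    mask = torch.eye (nClass, dtype = torch.bool, device = logits.device)
    lseOthers = z.masked_fill (mask, float ('-inf')).logsumexp (dim = 2)   # [nBatch, nClass]
    lseAll = logits.logsumexp (dim = 1, keepdim = True)                    # [nBatch, 1]
    logOneMinusP = lseOthers - lseAll                                      # log (1 - p_j)
    return (D[targ] * (-logOneMinusP)).mean()

# illustrative values (assumptions): a very confident wrong prediction
D = torch.tensor ([[0.0, 1.0, 1.0],
                   [1.0, 0.0, 1.0],
                   [1.0, 1.0, 0.0]])
logits = torch.tensor ([[100.0, 0.0, 0.0]])
targ = torch.tensor ([1])
lossExtreme = misclassBLogits (logits, targ, D)
print ('lossExtreme =', lossExtreme)   # finite, and not capped by a clamp floor
```

Because the penalty is built from logits, it keeps growing as the wrong class becomes more confident instead of flattening out at -log (1.e-7).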
Best.
K. Frank