KL Divergence for Multi-Label Classification

I need to use KL Divergence as my loss for a multi-label classification problem with 5 classes (Eqn. 6 of this paper). I have soft ground-truth targets from a teacher network of the form [0.99, 0.01, 0.99, 0.1, 0.1] for each sample (since it's a multi-label problem, a sample can belong to multiple classes), and predictions, which are five probabilities that don't sum to 1.

How can the KL divergence be computed? It seems like nn.KLDivLoss() expects the probabilities to add up to 1 – otherwise, I get a negative loss.

Hi Arjung!

The short answer is: use BCEWithLogitsLoss.

Several comments:

KL divergence and cross entropy are closely related. For fixed
targets, KL divergence and cross entropy differ by a constant that
is independent of your predictions (so it doesn’t affect training).
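
To see this concretely, here is a quick numeric check with made-up
distributions (not your data): KL (p || q) equals the cross entropy of
q against p minus the entropy of p, and that entropy term doesn't
depend on the prediction q, so it never changes the gradients.

```python
import torch

# made-up target and prediction distributions, just to illustrate
p = torch.tensor([0.7, 0.2, 0.1])   # fixed target distribution
q = torch.tensor([0.5, 0.3, 0.2])   # predicted distribution

kl  = (p * (p / q).log()).sum()     # KL divergence, KL (p || q)
ce  = -(p * q.log()).sum()          # cross entropy of q against p
ent = -(p * p.log()).sum()          # entropy of p (constant in q)

print(torch.allclose(kl, ce - ent))   # True
```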

You should understand your multi-label, 5-class task to be a set of
5 binary classification tasks (that share the same network). That is,
class-0 can be “yes” or “no”, class-1 can be “yes” or “no” (independent
of the value for class-0), and so on.

So, whether you call it KL divergence or cross entropy, you want to
calculate the loss for each of your 5 binary problems separately, and
then add (or average) them together. You could certainly do this
with KLDivLoss (with a loop), but KLDivLoss isn’t set up to do this
automatically. BCEWithLogitsLoss is.
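
Here is a sketch of what that "loop" version would look like, and how
it compares with BCEWithLogitsLoss. The tensors are made-up stand-ins
for your logits and teacher targets, with shape [nBatch, nClass = 5].

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits  = torch.randn(3, 5)   # stand-in for raw model outputs,   [nBatch, nClass]
targets = torch.rand(3, 5)    # stand-in for soft teacher targets, [nBatch, nClass]

# the "loop" version: one binary KL divergence per class, then summed
kl_loss = nn.KLDivLoss(reduction='sum')
loss_kl = 0.0
for c in range(5):
    p = torch.sigmoid(logits[:, c])                  # predicted "yes" probability for class c
    pred = torch.stack([p, 1 - p], dim=1).log()      # log of the (yes, no) distribution
    targ = torch.stack([targets[:, c], 1 - targets[:, c]], dim=1)
    loss_kl = loss_kl + kl_loss(pred, targ)

# BCEWithLogitsLoss handles all 5 binary problems in one call
loss_bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='sum')

# the two differ only by the entropy of the targets, which doesn't
# depend on the logits, so they give the same gradients
entropy = -(targets * targets.log() + (1 - targets) * (1 - targets).log()).sum()
print(loss_kl.item(), (loss_bce - entropy).item())   # the two values agree
```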

For BCEWithLogitsLoss you want your predictions (the output of your
model) to have shape [nBatch, nClass = 5] and be raw-score logits,
rather than probabilities. (So don't pass your logit predictions through
sigmoid() to convert them to probabilities; the sigmoid is, in effect,
built into BCEWithLogitsLoss.) Your targets should also have shape
[nBatch, nClass = 5] and should be the probabilities of each of
your samples being (independently) in each of your 5 classes.

(And to confirm, BCEWithLogitsLoss does accept “soft” targets that
are probabilities between 0.0 and 1.0, rather than requiring “hard”
0-1 targets.)
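
Putting the last two points together, a minimal usage sketch (with
made-up values; the first target row echoes the example from your
question) would be:

```python
import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

# raw-score logits straight from the model (no sigmoid applied)
logits = torch.randn(2, 5, requires_grad=True)

# soft probability targets from the teacher, values in [0.0, 1.0]
targets = torch.tensor([[0.99, 0.01, 0.99, 0.10, 0.10],
                        [0.01, 0.95, 0.05, 0.90, 0.10]])

loss = criterion(logits, targets)   # scalar, averaged over batch and classes
loss.backward()                     # gradients flow back through the logits
```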

Best.

K. Frank


Thanks a lot for such a detailed response! 🙂