Loss function for Floating targets

Hi Shuvayan!

Yes, pytorch’s CrossEntropyLoss() is a special case of cross-entropy
that requires integer categorical labels (“hard targets”) for its targets.
(It also takes logits, rather than probabilities, for its predictions.)
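
For example, a minimal illustration of the built-in, hard-target usage:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)              # (batch, num_classes) raw, unnormalized scores
labels = torch.randint(5, (4,))         # integer class labels in [0, 4] ("hard targets")
loss = F.cross_entropy(logits, labels)  # pytorch's built-in hard-target cross entropy
```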

It does sound like you want a general cross-entropy loss that takes
probabilities (“soft targets”) for its targets. This general version is not
built into pytorch.

But you can implement the general version using pytorch tensor
operations. See this earlier thread:

Note that the softXEnt() implemented in this post also takes logits
for its predictions. If your use case requires you to pass in probabilities
for your predictions (less numerically stable), you will have to modify
softXEnt() accordingly.
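
For reference, here is a minimal sketch of what such a soft-target cross
entropy can look like (the name soft_cross_entropy and the details here are
just my illustration, and may differ from the softXEnt() in that thread):

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(logits, target_probs):
    # logits:       (batch, num_classes) raw scores from the model
    # target_probs: (batch, num_classes) probabilities ("soft targets")
    log_probs = F.log_softmax(logits, dim=1)              # numerically stable log-probabilities
    return -(target_probs * log_probs).sum(dim=1).mean()  # average cross entropy over the batch

# example usage
logits = torch.randn(4, 5)
targets = torch.softmax(torch.randn(4, 5), dim=1)  # rows sum to one
loss = soft_cross_entropy(logits, targets)
```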

Good luck.

K. Frank
