Hi S.!
The short answer is that you have to write your own cross-entropy
function to do what you want – see below.
There are two things going on here:
First, as Aman noted, the input to CrossEntropyLoss (your softmax_out1) should be raw-score logits that range from -inf to +inf, rather than probabilities that range from 0.0 to 1.0. So you want to pass logits in as the input without converting them to probabilities by running them through softmax().
Second, CrossEntropyLoss expects its target (your softmax_out2) to be integer class labels (with shape [nBatch], rather than [nBatch, nClass]). So CategoricalCrossEntropyWithLogitsLoss might be a better (if lengthier) name for this loss function.
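For concreteness, here is a minimal sketch of that standard usage (the shapes and random tensors are just placeholders for your data):

```python
import torch

# standard CrossEntropyLoss usage: raw logits of shape [nBatch, nClass]
# as the input and integer class labels of shape [nBatch] as the target --
# no softmax() anywhere, CrossEntropyLoss applies log_softmax() internally
nBatch, nClass = 4, 3
logits = torch.randn(nBatch, nClass, requires_grad=True)   # raw scores
target = torch.randint(nClass, (nBatch,))                  # integer class labels

loss = torch.nn.CrossEntropyLoss()(logits, target)
loss.backward()
```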
Now, how to do what you want:
Even if you write your own cross-entropy loss function, you do not want to pass in probabilities for your input, as doing so will be less numerically stable than passing in logits.
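To illustrate (with made-up, deliberately extreme logits), taking the log of an already-computed probability can underflow to -inf, while working in log-space directly from the logits stays finite:

```python
import torch

# deliberately extreme logits to show the underflow problem
logits = torch.tensor([[200.0, 0.0]])

probs = torch.softmax(logits, dim=1)      # second probability underflows to 0.0
print(torch.log(probs))                   # tensor([[0., -inf]])
print(torch.log_softmax(logits, dim=1))   # tensor([[0., -200.]])
```

This is why CrossEntropyLoss (and the sketch further below) applies log_softmax() to the logits internally rather than log() to probabilities.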
It does, however, make sense to use probabilities (rather than integer class labels) for your target. (These are sometimes called soft labels or soft targets.) It’s just that pytorch doesn’t offer such a version of cross entropy.
The following post shows how to implement such a “soft cross-entropy” loss. It takes logits for its input (for numerical stability) and takes probabilities for its “soft-label” target:
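As a rough sketch (the function name soft_cross_entropy is my own, and this is the usual logits-plus-log_softmax formulation rather than necessarily the exact code from that post), such a loss could look like:

```python
import torch
import torch.nn.functional as F

def soft_cross_entropy(input, target):
    # input:  logits of shape [nBatch, nClass]
    # target: probabilities ("soft labels") of shape [nBatch, nClass], rows summing to 1.0
    log_probs = F.log_softmax(input, dim=1)
    return -(target * log_probs).sum(dim=1).mean()

# example usage with made-up data
nBatch, nClass = 4, 3
logits = torch.randn(nBatch, nClass, requires_grad=True)
soft_target = torch.softmax(torch.randn(nBatch, nClass), dim=1)  # rows sum to 1.0
loss = soft_cross_entropy(logits, soft_target)
loss.backward()
```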
Best.
K. Frank