Multi-class cross-entropy loss and softmax in PyTorch


In this topic, ptrblck said that an F.softmax at dim=1 should be added before nn.CrossEntropyLoss().
In the documentation (https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss), the loss is computed as nll_loss(log_softmax(input, 1)), i.e. it already combines the negative log and the softmax.
My question is: should I calculate the softmax at dim=1 before calling nn.CrossEntropyLoss, which already applies a softmax at dim=1?

No, F.softmax should not be added before nn.CrossEntropyLoss.
I’ll take a look at the thread and edit the answer if possible, as this might be a careless mistake!
Thanks for pointing this out.

EDIT: Indeed the example code had an F.softmax applied on the logits, although not explicitly mentioned.
To sum it up: nn.CrossEntropyLoss applies F.log_softmax and nn.NLLLoss internally on your input, so you should pass the raw logits to it.
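
For example, a minimal sketch (the shapes and values are made up for illustration: a batch of 3 samples and 5 classes):

import torch
import torch.nn as nn

logits = torch.randn(3, 5)        # raw model outputs, no softmax applied
target = torch.tensor([1, 0, 4])  # ground-truth class indices

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # log_softmax + NLLLoss happen internally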


What loss function are we supposed to use when we use the F.softmax layer?

Hi Brando!

If you want to use a cross-entropy-like loss function, you shouldn’t
use a softmax layer because of the well-known problem of increased
risk of overflow.

I gave a few words of explanation about this problem in a reply in
another thread:

You should either use nn.CrossEntropyLoss (which takes
pre-softmax logits, rather than post-softmax probabilities)
without a softmax-like layer, or use an nn.LogSoftmax layer,
and feed the results into nn.NLLLoss. (Both of these combine
an implicit softmax with the subsequent log in a way that avoids
the enhanced overflow problem.)
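
A small sketch of both options (again with made-up shapes, just to show they agree):

import torch
import torch.nn as nn

logits = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])

# Option 1: raw logits straight into nn.CrossEntropyLoss
loss1 = nn.CrossEntropyLoss()(logits, target)

# Option 2: nn.LogSoftmax followed by nn.NLLLoss
log_probs = nn.LogSoftmax(dim=1)(logits)
loss2 = nn.NLLLoss()(log_probs, target)

print(torch.allclose(loss1, loss2))  # True -- the two are numerically equivalent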

If you are stuck for some reason with your softmax layer, you
should run the probabilities output by softmax through log(),
and then feed the log-probabilities to nn.NLLLoss (but expect
increased risk of overflow).
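
That workaround would look something like this (just a sketch; note the stability caveat above):

import torch
import torch.nn as nn

probs = torch.softmax(torch.randn(3, 5), dim=1)  # post-softmax probabilities
target = torch.tensor([1, 0, 4])

# log() of the probabilities gives the log-probabilities nn.NLLLoss expects,
# but a separate softmax + log is less numerically stable than log_softmax.
loss = nn.NLLLoss()(torch.log(probs), target)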

(I am not aware of any single pytorch cross-entropy loss function
that takes post-softmax probabilities directly.)

Good luck!

K. Frank


Hi, if softmax is not to be used, how do we get the output as probabilities for a multi-class classification problem? I have explained my problem here. Please take a look at my code and help me out since I am a beginner.

You can just apply softmax to your output as normal. So
model.eval()
output = model(input)           # raw logits from the network
sm = torch.nn.Softmax(dim=1)    # softmax over the class dimension
probabilities = sm(output)
print(probabilities)