# Multi-class cross entropy loss and softmax in pytorch

In this topic, ptrblck said that `F.softmax` at `dim=1` should be applied before `nn.CrossEntropyLoss()`.
In the documentation (https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss), the loss returns `nll_loss(log_softmax(input, 1), target)`, i.e., the negative log-likelihood of the log-softmax.
My question: should I apply softmax at `dim=1` before `nn.CrossEntropyLoss`, which already applies a (log-)softmax at `dim=1`?

No, `F.softmax` should not be added before `nn.CrossEntropyLoss`.
I’ll take a look at the thread and edit the answer if possible, as this might be a careless mistake!
Thanks for pointing this out.

EDIT: Indeed, the example code applied `F.softmax` to the logits, although this was not explicitly mentioned.
To sum it up: `nn.CrossEntropyLoss` applies `F.log_softmax` and `nn.NLLLoss` internally on your input, so you should pass the raw logits to it.
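A quick numerical check of this, sketched with random logits and targets (the batch size and class count are arbitrary): `nn.CrossEntropyLoss` on raw logits matches `nn.NLLLoss` on `F.log_softmax` output, while applying `F.softmax` first gives a different value, because the softmax then effectively runs twice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up batch: 4 samples, 3 classes
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])

# Correct usage: pass the raw logits
ce = nn.CrossEntropyLoss()(logits, target)

# What CrossEntropyLoss computes: log_softmax followed by NLLLoss
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

# Incorrect usage: softmax before the loss, so softmax is applied twice
ce_double = nn.CrossEntropyLoss()(F.softmax(logits, dim=1), target)

print(ce.item(), nll.item(), ce_double.item())
print(torch.allclose(ce, nll))  # True
```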


What loss function are we supposed to use when we use the `F.softmax` layer?

Hi Brando!

If you want to use a cross-entropy-like loss function, you shouldn’t
use a softmax layer because of the well-known problem of increased
risk of overflow.

I gave a few words of explanation about this problem in a reply in

You should either use `nn.CrossEntropyLoss` (which takes
pre-softmax logits, rather than post-softmax probabilities)
without a softmax-like layer, or use a `nn.LogSoftmax` layer,
and feed the results into `nn.NLLLoss`. (Both of these combine
an implicit softmax with the subsequent log in a way that avoids
the enhanced overflow problem.)
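The following sketch, using deliberately extreme made-up logits, illustrates the kind of numerical failure meant here: in float32, computing softmax and then `log()` turns a tiny probability into exactly `0`, whose log is `-inf`, whereas `F.log_softmax` (the computation behind `nn.LogSoftmax` and `nn.CrossEntropyLoss`) stays finite.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Made-up, deliberately extreme logits: 1 sample, 2 classes
logits = torch.tensor([[0.0, 200.0]])
target = torch.tensor([0])  # the true class gets a tiny probability

# Naive route: softmax, then log. exp(-200) is too small for float32 and becomes 0.
probs = F.softmax(logits, dim=1)
naive_log_probs = torch.log(probs)
print(naive_log_probs)  # first entry is -inf

# Stable route: log_softmax computes the log-probabilities directly
stable_log_probs = F.log_softmax(logits, dim=1)
print(stable_log_probs)  # approximately [[-200., 0.]]

loss_fn = nn.NLLLoss()
print(loss_fn(naive_log_probs, target))   # inf
print(loss_fn(stable_log_probs, target))  # approximately 200
print(nn.CrossEntropyLoss()(logits, target))  # also approximately 200
```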

If you are stuck for some reason with your softmax layer, you
should run the probabilities output by softmax through `log()`,
and then feed the log-probabilities to `nn.NLLLoss` (but expect
increased risk of overflow).
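A minimal sketch of that fallback, assuming a hypothetical model whose output has already gone through softmax (simulated here with random logits). For moderate logits like these, the result agrees with `nn.CrossEntropyLoss` on the raw logits; the difference only shows up when some probability underflows to zero.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical model output that already went through softmax
logits = torch.randn(4, 3)
target = torch.tensor([0, 2, 1, 2])
probs = F.softmax(logits, dim=1)

# Fallback: log of the probabilities, then NLLLoss.
# Any probability that underflowed to 0 would become -inf here.
loss_from_probs = nn.NLLLoss()(torch.log(probs), target)

# Reference: the stable route on raw logits
loss_from_logits = nn.CrossEntropyLoss()(logits, target)

print(torch.allclose(loss_from_probs, loss_from_logits, atol=1e-6))  # True for these logits
```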

(I am not aware of any single pytorch cross-entropy loss function
that takes post-softmax probabilities directly.)

Good luck!

K. Frank


Hi, if softmax is not to be used, how do we get the output as probabilities for a multi-class classification problem? I have explained my problem here. Please take a look at my code and help me out since I am a beginner.

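You can just apply softmax to your model's output as usual, outside of the loss computation (e.g., at inference time). Pass `dim=1` so the softmax normalizes over the class dimension:

```python
import torch

# `model` is your trained network and `input` your input batch
model.eval()  # evaluation mode (disables dropout, uses running batchnorm stats)
with torch.no_grad():
    output = model(input)  # raw logits, shape [batch_size, num_classes]
    sm = torch.nn.Softmax(dim=1)  # normalize over the class dimension
    probabilities = sm(output)
print(probabilities)
```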
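Putting the thread together, here is a minimal end-to-end sketch using a hypothetical toy classifier (`nn.Linear(5, 3)`) and random data, both assumptions for illustration only: the model has no softmax layer, training feeds raw logits to `nn.CrossEntropyLoss`, and softmax is applied only when probabilities are needed at inference time.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy classifier: 5 input features, 3 classes, no softmax layer
model = nn.Linear(5, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Made-up training batch
x = torch.randn(8, 5)
y = torch.randint(0, 3, (8,))

# Training step: the loss receives raw logits
model.train()
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()

# Inference: convert logits to probabilities only here
model.eval()
with torch.no_grad():
    probabilities = torch.softmax(model(x), dim=1)
    predictions = probabilities.argmax(dim=1)

print(probabilities.sum(dim=1))  # each row sums to 1
print(predictions)
```

Since softmax is monotonic, `predictions` equals the argmax of the raw logits; the softmax step is only needed when actual probability values are required.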