In this topic, ptrblck said that an F.softmax call with dim=1 should be added before nn.CrossEntropyLoss().
According to the documentation (https://pytorch.org/docs/stable/nn.html?highlight=crossentropy#torch.nn.CrossEntropyLoss), it returns nll_loss(log_softmax(input, 1)), i.e. it already applies a log-softmax followed by the negative log-likelihood loss.
My question is: should I compute the softmax at dim=1 before nn.CrossEntropyLoss, which already applies a softmax at dim=1?
F.softmax should not be added before nn.CrossEntropyLoss.
I’ll take a look at the thread and edit the answer if possible, as this might be a careless mistake!
Thanks for pointing this out.
EDIT: Indeed, the example code had an
F.softmax applied to the logits, although this was not explicitly mentioned.
To sum it up:
nn.CrossEntropyLoss applies F.log_softmax and nn.NLLLoss internally on your input, so you should pass the raw logits to it.
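As a quick check of the summary above, here is a minimal sketch comparing three ways of computing the loss. The logits and targets are made-up illustrative values, not taken from the thread; the calls used (nn.CrossEntropyLoss, F.log_softmax, F.nll_loss, F.softmax) are the ones discussed here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative data (assumed): 4 samples, 3 classes.
logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0],
                       [-1.0, 2.5, 0.3],
                       [1.5, -0.5, 0.0]])
target = torch.tensor([0, 2, 1, 2])

criterion = nn.CrossEntropyLoss()

# Recommended: pass the raw logits straight to the criterion.
loss_logits = criterion(logits, target)

# The same computation written out by hand.
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), target)

# The mistake discussed in this thread: softmax is effectively applied twice.
loss_double = criterion(F.softmax(logits, dim=1), target)

print(loss_logits.item(), loss_manual.item(), loss_double.item())
```

The first two values should agree, while the third differs, because the criterion treats the already-normalized probabilities as if they were logits.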
What loss function are we supposed to use when we use the softmax layer?
If you want to use a cross-entropy-like loss function, you shouldn’t
use a softmax layer because of the well-known problem of increased
risk of overflow.
I gave a few words of explanation about this problem in a reply in
You should either use
nn.CrossEntropyLoss (which takes
pre-softmax logits, rather than post-softmax probabilities)
without a softmax-like layer, or use a nn.LogSoftmax layer
and feed the results into
nn.NLLLoss. (Both of these combine
an implicit softmax with the subsequent log in a way that avoids
the enhanced overflow problem.)
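To illustrate why fusing the softmax with the log helps, here is a plain-Python sketch of the standard log-sum-exp trick. It is only an illustration of the idea; I am not claiming it mirrors PyTorch's internal implementation, and the function names and input values are hypothetical.

```python
import math

def naive_log_softmax(logits):
    # Exponentiate, normalize, then take the log as separate steps.
    exps = [math.exp(x) for x in logits]  # math.exp raises OverflowError for large x
    total = sum(exps)
    return [math.log(e / total) for e in exps]

def stable_log_softmax(logits):
    # Subtract the maximum first, so every exponent is <= 0 and exp cannot overflow.
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]

logits = [1000.0, 999.0, 998.0]  # deliberately large, illustrative values

try:
    naive_log_softmax(logits)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable = stable_log_softmax(logits)
print(naive_overflowed, stable)
```

The fused version never materializes the huge exponentials, which is the benefit of working with logits and log-probabilities rather than with probabilities.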
If you are stuck for some reason with your softmax layer, you
should run the probabilities output by softmax through torch.log
and then feed the log-probabilities to
nn.NLLLoss (but expect
increased risk of overflow).
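A minimal sketch of this fallback, assuming a model whose last layer is a softmax (stood in for here by F.softmax on made-up logits). With moderate values the result matches nn.CrossEntropyLoss on the logits; with extreme values the float32 probability of the target class can round to zero, so the log and the loss stop being finite.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

nll = nn.NLLLoss()
ce = nn.CrossEntropyLoss()
target = torch.tensor([2])

# Moderate logits (illustrative): both routes should agree.
logits = torch.tensor([[2.0, 0.5, -1.0]])
probs = F.softmax(logits, dim=1)  # stand-in for a model that ends in softmax
loss_from_probs = nll(torch.log(probs), target)
loss_from_logits = ce(logits, target)

# Extreme logits (illustrative): the target-class probability is too small
# for float32, so taking its log no longer gives a finite value.
big_logits = torch.tensor([[100.0, 0.0, -100.0]])
big_probs = F.softmax(big_logits, dim=1)
big_loss_from_probs = nll(torch.log(big_probs), target)
big_loss_from_logits = ce(big_logits, target)

print(loss_from_probs.item(), loss_from_logits.item())
print(big_loss_from_probs.item(), big_loss_from_logits.item())
```

This is why the log-of-softmax route is only a fallback: it works for well-behaved outputs but loses information that the logits-based loss still has.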
(I am not aware of any single PyTorch cross-entropy loss function
that takes post-softmax probabilities directly.)
Hi, if softmax is not to be used, how do we get the output as probabilities for a multi-class classification problem? I have explained my problem here. Please take a look at my code and help me out since I am a beginner.
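You can still apply softmax to your model's output when you need probabilities; just don't pass the result to nn.CrossEntropyLoss. So
output = net(input)  # raw logits, assumed shape [batch_size, num_classes]
sm = torch.nn.Softmax(dim=1)  # normalize over the class dimension
probabilities = sm(output)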
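For completeness, a self-contained sketch of that inference step. The nn.Linear model and the random inputs are hypothetical stand-ins for your own network and data; the point is only that softmax is applied to the logits after the forward pass, outside the loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in model: 5 input features, 3 classes, no softmax layer.
net = nn.Linear(5, 3)
net.eval()

inputs = torch.randn(4, 5)  # a batch of 4 made-up samples

with torch.no_grad():
    output = net(inputs)                       # raw logits, shape [4, 3]
    probabilities = F.softmax(output, dim=1)   # each row sums to 1
    predicted = probabilities.argmax(dim=1)    # predicted class per sample

print(probabilities)
print(predicted)
```

Since softmax preserves the ordering of the logits, the predicted class is the same whether you take the argmax of the logits or of the probabilities; the softmax is only needed when you want the probability values themselves.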