You should pass raw logits to nn.CrossEntropyLoss, since the function itself applies F.log_softmax and nn.NLLLoss() to the input.
If you pass probabilities (from nn.Softmax()), softmax is effectively applied a second time and your loss function won't work as intended. Log probabilities (from nn.LogSoftmax) belong with nn.NLLLoss instead.
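
A minimal sketch of the difference, assuming a toy batch of 3 samples and 3 classes (the tensor values and variable names are made up for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits: each row is confidently correct for its target class.
logits = torch.tensor([[5.0, 0.0, 0.0],
                       [0.0, 5.0, 0.0],
                       [0.0, 0.0, 5.0]])
targets = torch.tensor([0, 1, 2])

criterion = nn.CrossEntropyLoss()

# Intended usage: raw logits go straight into the criterion.
loss_logits = criterion(logits, targets)

# Same value computed by hand: log_softmax followed by NLLLoss.
loss_manual = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)

# Misuse: probabilities are passed, so softmax is applied a second time.
loss_probs = criterion(F.softmax(logits, dim=1), targets)

print(loss_logits.item(), loss_manual.item(), loss_probs.item())
```

Here the first two values should match, while the third stays much larger even though every prediction is correct: probabilities lie in [0, 1], so the second softmax squashes them toward a uniform distribution and the loss cannot get close to zero.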