Where to use softmax?

I am implementing a ResNet-50 model and was wondering whether I should use nn.Softmax() inside the model or outside it (in the training loop).
I can’t see why it would make a difference, though in a few implementations I’ve seen, people use it outside.


Hi Aryaman!

I won’t speak directly to ResNet-50, but I assume that your use case
is image classification. It’s not the only reasonable choice, but one
would typically train such a classifier with CrossEntropyLoss.

In PyTorch (for good reason), CrossEntropyLoss expects predictions
that are raw-score logits (rather than, say, probabilities). For this
reason, in your training loop, you want Softmax (which converts logits
to probabilities) neither inside nor outside of your model. You would
typically feed the output of the final Linear layer of your classifier
(understood to be logits) directly to CrossEntropyLoss.
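
As a rough sketch of what that looks like in a training step (the tiny
model, the toy batch, and the hyperparameters here are made up for
illustration and are not ResNet-50):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for a classifier: the final layer is a plain Linear
# with no Softmax after it, so the model's output is raw logits.
num_classes = 10
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, num_classes))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Toy batch: 4 random "images" and integer class labels.
images = torch.randn(4, 3, 8, 8)
labels = torch.randint(0, num_classes, (4,))

logits = model(images)            # shape (4, num_classes), raw scores
loss = criterion(logits, labels)  # logits go straight into the loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that there is no Softmax anywhere between the Linear layer and the loss.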

If you happen to need the probabilities (for some reason other than
computing the loss), you would pass the logits through Softmax,
but it’s often the case that you don’t need explicit probabilities and
can perform any additional desired processing with the logits directly.
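
For example, picking the predicted class is one thing you can do with
the logits directly. Here is a small sketch with made-up random logits:
softmax is monotonic within each row, so the argmax of the probabilities
is the same as the argmax of the logits.

```python
import torch

torch.manual_seed(0)

# Made-up logits for a batch of 5 samples and 3 classes.
logits = torch.randn(5, 3)

# Explicit probabilities, only if you actually need them.
probs = torch.softmax(logits, dim=1)

# The predicted class can be read off either tensor.
pred_from_logits = logits.argmax(dim=1)
pred_from_probs = probs.argmax(dim=1)

print(torch.equal(pred_from_logits, pred_from_probs))
```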

So: No Softmax for training (neither inside nor outside of your model),
and probably no Softmax at all.

(As a side note, internal to PyTorch’s CrossEntropyLoss is a
log_softmax() computation, so a conversion to probabilities is
being carried out under the hood, but in log-probability space.)
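
If you want to check this yourself, here is a minimal sketch (with
made-up logits and targets) comparing cross_entropy() with log_softmax()
followed by nll_loss(). With the default mean reduction, the two should
agree up to floating-point round-off.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up logits for 4 samples and 5 classes, with integer class targets.
logits = torch.randn(4, 5)
targets = torch.tensor([0, 2, 4, 1])

# Cross entropy computed directly from the logits.
loss_ce = F.cross_entropy(logits, targets)

# The same quantity in two steps: log-probabilities, then negative
# log-likelihood of the target class.
log_probs = F.log_softmax(logits, dim=1)
loss_manual = F.nll_loss(log_probs, targets)

print(torch.allclose(loss_ce, loss_manual))
```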


K. Frank
