nn.Module best practices: should it output logits or probabilities?

When using nn.Module, what is the best practice (or what is commonly used): outputting the logits or the probabilities?

Consider these two simple cases:

1. the model outputs the logits:

class Network(nn.Module):
    def forward(self, x):
        ...
        logits = self.last_layer(x)
        return logits

# Training
data, target = ...
net = Network()
logits = net(data)
loss = nn.functional.some_loss_with_logits(logits, target)
...

# Predicting
data = ...
logits = net(data)
probabilities = nn.functional.some_squashing_function(logits)

Pros:

  • the training code is clean
  • using a loss function with logits is straightforward

Cons:

  • during inference, one has to remember that the network outputs are logits, not probabilities, and apply whatever sigmoid/softmax/other squashing is required
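A minimal runnable sketch of this first option, assuming a hypothetical 3-class classifier (the name LogitNet, the Linear(10, 3) layer, and the dummy data are all illustrative), with F.cross_entropy standing in for the loss-with-logits and F.softmax for the squashing function:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LogitNet(nn.Module):
    """Hypothetical 3-class classifier whose forward returns raw logits."""
    def __init__(self):
        super().__init__()
        self.last_layer = nn.Linear(10, 3)

    def forward(self, x):
        return self.last_layer(x)  # logits, no softmax here

# Training: cross_entropy expects logits and applies log-softmax internally
net = LogitNet()
data, target = torch.randn(4, 10), torch.tensor([0, 2, 1, 0])
loss = F.cross_entropy(net(data), target)
loss.backward()

# Predicting: squash explicitly, only when probabilities are actually needed
with torch.no_grad():
    probabilities = F.softmax(net(data), dim=1)
```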

2. the model outputs the probabilities:

class Network(nn.Module):
    def logits(self, x):
        ...
        logits = self.last_layer(x)
        return logits

    def forward(self, x):
        logits = self.logits(x)
        return nn.functional.some_squashing_function(logits)



# Training
data, target = ...
net = Network()
logits = net.logits(data)
loss = nn.functional.some_loss_with_logits(logits, target)
...

# Predicting
data = ...
probabilities = net(data)

Pros:

  • making a prediction now looks more PyTorch-like

Cons:

  • during training, one has to remember to call the custom logits method instead of the model itself
  • this doesn’t play well with wrappers like nn.DataParallel, which only forward forward: the custom logits method is no longer exposed on the wrapped model
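The DataParallel issue can be seen in a short sketch (ProbNet and its shapes are hypothetical): the wrapper exposes only forward, so the custom logits method is not reachable on the wrapped object and training code has to dig into .module:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProbNet(nn.Module):
    """Hypothetical classifier whose forward returns probabilities."""
    def __init__(self):
        super().__init__()
        self.last_layer = nn.Linear(10, 3)

    def logits(self, x):
        return self.last_layer(x)

    def forward(self, x):
        return F.softmax(self.logits(x), dim=1)

net = ProbNet()
data, target = torch.randn(4, 10), torch.tensor([0, 2, 1, 0])

# Training must call the custom method directly
loss = F.cross_entropy(net.logits(data), target)

# ...which breaks once the model is wrapped:
wrapped = nn.DataParallel(net)
assert not hasattr(wrapped, "logits")        # the wrapper hides it
assert hasattr(wrapped.module, "logits")     # only reachable via .module
```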

A third option would be to turn the squashing layer on and off based on the mode of the network, i.e. net.train() and net.eval(). However, this would hide the network’s behavior even more.
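For completeness, this third option could be sketched via the self.training flag, which net.train() and net.eval() toggle (ModeNet and its shapes are again hypothetical; note how the return value silently changes meaning with the mode):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModeNet(nn.Module):
    """Hypothetical sketch: squash only in eval mode."""
    def __init__(self):
        super().__init__()
        self.last_layer = nn.Linear(10, 3)

    def forward(self, x):
        logits = self.last_layer(x)
        if self.training:                     # net.train(): raw logits for the loss
            return logits
        return F.softmax(logits, dim=1)       # net.eval(): probabilities

net = ModeNet()
x = torch.randn(2, 10)

net.train()
train_out = net(x)          # logits
net.eval()
with torch.no_grad():
    eval_out = net(x)       # probabilities
```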


I believe the first one is much better. The squashing function does not change the result of inference: softmax/sigmoid are monotonic, so picking the class with the highest probability and picking the class with the highest logit give the same answer.

So making a prediction is always the same, a variant of output.max(dim=1)[1] (equivalently, output.argmax(dim=1)). In particular, in your first example there is no need to compute probabilities first (unless you want to show them to the user or you are working on some edge case).
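A quick check of this claim (the random logits are purely illustrative): softmax is monotonic, so the argmax is identical on logits and probabilities, and output.max(dim=1)[1] is just an older spelling of output.argmax(dim=1):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 3)
probabilities = F.softmax(logits, dim=1)

# softmax is monotonic, so the predicted class is the same either way
assert torch.equal(logits.argmax(dim=1), probabilities.argmax(dim=1))
# max(dim=1) returns (values, indices); [1] picks the indices, i.e. argmax
assert torch.equal(logits.max(dim=1)[1], logits.argmax(dim=1))
```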
