`torch.argmax` is used for a multi-class classification to compute the prediction from a model output in the shape `[batch_size, nb_classes, *]`, where the `argmax` is called on the `nb_classes` dimension.
Passing the logits to a softmax before calling `torch.argmax` won’t make a difference, since the largest logit will also have the highest probability (softmax is monotonic).
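As a quick illustration of both points (random logits stand in for a real model output here):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 5)  # hypothetical output in the shape `[batch_size=4, nb_classes=5]`
pred_from_logits = torch.argmax(logits, dim=1)                   # argmax on the class dimension
pred_from_probs = torch.argmax(F.softmax(logits, dim=1), dim=1)  # argmax on the probabilities
print(torch.equal(pred_from_logits, pred_from_probs))            # True: softmax doesn't change the argmax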
Passing the logits to a sigmoid and calling `torch.argmax` sounds wrong, since in this case you should apply a threshold instead.
It depends on your use case and the model output.
For a binary classification you can define the model output in the shape `[batch_size, 1]` and return the logits (you would use `nn.BCEWithLogitsLoss` in this case). To get the predicted class you can apply a threshold directly to the logits (at 0.0) or to the probability after passing the logits to `torch.sigmoid` (at 0.5).
Here is a small code example:
output = model(input) # output shape is `[batch_size, 1]` and contains logits
output_prob = torch.sigmoid(output) # calculate probabilities
pred = output_prob > 0.5 # apply threshold to get class predictions
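Note that `pred` will be a boolean tensor here; you can cast it, e.g. via `pred.long()` or `pred.float()`, before comparing it to your targets.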
Alternatively, you can treat the binary classification as a 2-class multi-class classification, where the model output would have the shape `[batch_size, 2]` (you would use `nn.CrossEntropyLoss` in this case). To get the predictions you would call `preds = torch.argmax(output, dim=1)` on the logits or on the softmax output.
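As a rough sketch, assuming a hypothetical `model` that now returns two logits per sample:

output = model(input) # output shape is `[batch_size, 2]` and contains logits
preds = torch.argmax(output, dim=1) # same result on the logits or on the softmax output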
`torch.exp` is used to compute the probabilities after you’ve applied e.g. `log_softmax` to the logits.
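A minimal, self-contained sketch (the tensor shape is just an assumption):

import torch
import torch.nn.functional as F

output = torch.randn(4, 2) # hypothetical logits in the shape `[batch_size, 2]`
log_probs = F.log_softmax(output, dim=1) # e.g. the output of a final `nn.LogSoftmax` layer
probs = torch.exp(log_probs) # recover the probabilities; each row sums to ~1.0
print(probs.sum(dim=1))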