For binary classification you need only one logit, so a linear layer that maps its input to a single neuron is adequate. You also need to put a threshold on the logit output by that linear layer. That said, an activation such as sigmoid as the last layer seems more reasonable to me.
Why should the model see only class 1 samples? In your first post you mentioned that you are using two separate folders, one for class 1 and the other for classes 0, 2, 3, 4. So the model will see all samples and learn class 1 as 1, and classes 0, 2, 3, 4 as "not class 1", i.e. zero.
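If the original five-class integer labels are available instead of two folders, the same "class 1 vs. the rest" targets can be built directly. This is only a sketch; the `labels` tensor below is made-up example data, not something from the thread.

```python
import torch

# Hypothetical batch of original labels drawn from the five classes 0..4.
labels = torch.tensor([0, 1, 2, 1, 4, 3])

# Class 1 becomes the positive target (1.0); every other class becomes 0.0.
# The targets are cast to float because BCEWithLogitsLoss expects float targets.
binary_targets = (labels == 1).float()

print(binary_targets)
```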
You can also define a separate binary model for each class and combine them. This approach is called one-vs-all (OneVsAll).
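One possible way to combine such per-class binary models is to stack their single logits and pick the class with the largest one. The following is a rough sketch under assumptions: the feature size, the number of classes, and the use of plain `nn.Linear` models are illustrative choices, and the models here are untrained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_features, n_classes = 8, 5  # hypothetical sizes

# One independent binary classifier (a single logit each) per class.
binary_models = nn.ModuleList(
    [nn.Linear(n_features, 1) for _ in range(n_classes)]
)

x = torch.randn(4, n_features)  # hypothetical batch of 4 samples

with torch.no_grad():
    # Each model returns shape (4, 1); concatenating gives shape (4, n_classes).
    logits = torch.cat([m(x) for m in binary_models], dim=1)

# Predict the class whose binary model produces the largest logit.
predicted_class = logits.argmax(dim=1)
```

In practice each binary model would first be trained on its own "class k vs. rest" targets before the logits are compared.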
Just to clarify something, for a binary-classification problem, you
are best off using the logits that come out of a final Linear layer,
with no threshold or Sigmoid activation, and feed them into
BCEWithLogitsLoss. (Using Sigmoid and BCELoss is less
numerically stable.)
And, as Doosti recommended, your last layer should have a single
output, rather than 2. Thus:
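As a minimal sketch of such a single-output setup (the layer sizes, the batch, and the targets are illustrative assumptions, not the original poster's model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_features = 8  # hypothetical input size

# The final Linear layer has a single output (one logit), not two.
model = nn.Sequential(
    nn.Linear(n_features, 16),
    nn.ReLU(),
    nn.Linear(16, 1),
)

# BCEWithLogitsLoss applies the sigmoid internally, so the model emits raw logits.
loss_fn = nn.BCEWithLogitsLoss()

x = torch.randn(4, n_features)                 # hypothetical batch
targets = torch.tensor([1.0, 0.0, 0.0, 1.0])   # float 0/1 labels

logits = model(x).squeeze(1)  # shape (4, 1) -> (4,), matching the targets
loss = loss_fn(logits, targets)
loss.backward()
```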
The short answer is that you threshold your single logit output
against 0.0, rather than running a set of nClass outputs through argmax().
Let me confirm what I think you are asking:
In addition to calculating your loss function (used for training), you
often also want to calculate the accuracy of your predictions.
For a multi-class classification problem, you typically pass a set
of nClass predicted logits (or predicted probabilities) through
argmax() to get the single predicted integer class label (that
you then compare with your known class label). I assume that
this is the “argmax” you are talking about.
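For reference, the multi-class case might look like the sketch below. The logits and labels are made-up values for three samples and four classes.

```python
import torch

# Hypothetical logits for 3 samples and nClass = 4 classes.
logits = torch.tensor([
    [2.0, 0.1, -1.0, 0.3],
    [0.2, 0.1, 3.0, -0.5],
    [-1.0, 0.5, 0.4, 0.0],
])
labels = torch.tensor([0, 2, 3])  # known integer class labels

# argmax over the class dimension gives one predicted label per sample.
predictions = logits.argmax(dim=1)

# Fraction of samples whose predicted label matches the known label.
accuracy = (predictions == labels).float().mean()
```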
For a binary problem, your last Linear layer will output a single
predicted logit for the sample being in class-“1” (as opposed to
being in class-“0”). (Or, if you pass this logit through a sigmoid(),
you will get the predicted probability of the sample being in class-“1”.)
In this case you threshold the output to get a binary prediction: logit > 0.0 == True means you predict that the sample is
in class-“1” (and logit > 0.0 == False means class-“0”). (If
you are working with probabilities, then prob > 0.5 == True
means class-“1”.) You then compare this prediction with the known
class-“0” / class-“1” binary label for the sample in question.
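The binary case can be sketched the same way. The logits and labels below are made-up values; the point is that thresholding the logit at 0.0 and thresholding the sigmoid probability at 0.5 give the same predictions.

```python
import torch

# Hypothetical single-logit outputs for 5 samples, with known 0/1 labels.
logits = torch.tensor([1.5, -0.3, 0.7, 2.2, -4.0])
labels = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0])

# Threshold the logits at 0.0 ...
pred_from_logits = logits > 0.0

# ... or, equivalently, threshold the probabilities at 0.5.
pred_from_probs = torch.sigmoid(logits) > 0.5

# Compare the boolean predictions with the known binary labels.
accuracy = (pred_from_logits == labels.bool()).float().mean()
```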