Train a custom classifier with a limited number of classes

I need to train a classifier on CIFAR-10 such that it classifies images with labels 0 and 1, and for labels other than these two it provides unbiased (neutral) results. I tried to train it with a BCE loss using the following code:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    input = model(x)                                # logits, shape (N, 10)
    not_ind_tgt = ~((target == 0) | (target == 1))  # mask: labels other than 0/1
    tgt_hot = F.one_hot(target, 10).float()         # one-hot targets, shape (N, 10)
    # "neutral" target: zero on classes 0 and 1, uniform 1/8 over classes 2-9
    multi_label = torch.ones(1, 10) / 8.0
    multi_label[0, 0:2] = 0.0
    tgt_hot[not_ind_tgt] = multi_label
    # BCEWithLogitsLoss already reduces to the mean, so no extra .mean() is needed
    loss = nn.BCEWithLogitsLoss()(input, tgt_hot)

However, the accuracy of the trained classifier on classes 0 and 1 is zero.

  1. If you look at the docs, nn.BCEWithLogitsLoss() expects input and target of shape (N, *).
    Essentially, this means that if you use an input of shape (N, 10), then 10 individual BCE losses will be computed; it's sort of a multi-label BCE, which is not your end goal.
  2. You should rather use nn.CrossEntropyLoss(), because you have two labels which are explicitly coded as the zeroth and first class.

For example, let’s say input is of shape (3,10), with 3 being the batch-size.
Your target must be of shape (3, 1) or just (3,), and each entry in this variable is the class index for that particular batch element.
In your case, the target would contain only 0 and 1, since you are focusing only on the zeroth and first class.
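The shape convention described above can be sketched as follows (the logits here are random stand-ins for a model's output):

```python
import torch
import torch.nn as nn

# Hypothetical logits for a batch of 3 images over 10 classes.
logits = torch.randn(3, 10)
# Targets of shape (3,): one class index per batch element.
targets = torch.tensor([0, 1, 0])

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)
print(loss.shape)  # torch.Size([]) -- a scalar, reduced over the batch
```

Note that nn.CrossEntropyLoss takes raw logits and class indices, not one-hot vectors.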

My dataset does contain images with labels other than these two, and I want the classifier to be neutral on those, i.e., to give zero probability to labels 0 and 1 for images whose labels are not 0 or 1. So the targets cannot contain only 0 and 1.

Okay, got it!
In that case, you could use something like class weights in the loss function.
For more details, see https://pytorch.org/docs/stable/nn.html#crossentropyloss
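As a sketch of the `weight` argument of nn.CrossEntropyLoss (the numbers below are illustrative, not tuned):

```python
import torch
import torch.nn as nn

# Per-class weights: up-weight classes 0 and 1, down-weight the remaining
# eight so they contribute little to the loss. Values are only illustrative.
class_weights = torch.tensor([10.0, 10.0] + [0.01] * 8)

criterion = nn.CrossEntropyLoss(weight=class_weights)

logits = torch.randn(4, 10)
targets = torch.tensor([0, 1, 2, 9])
loss = criterion(logits, targets)
```

With the default `reduction='mean'`, the loss is normalized by the sum of the weights of the targets in the batch, so scaling all weights by a constant has no effect.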

I have 10 classes and each class has 1000 images in the dataset. It isn't clear to me what the weights should be in this case.
Do you think this could give good accuracy on labels 0 and 1? I tried but I failed.

You can find the class-weights using this link https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
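One caveat with that approach, sketched below: compute_class_weight's "balanced" heuristic only corrects for class imbalance, and a dataset with exactly 1000 images per class is already balanced, so it returns uniform weights:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Simulated CIFAR-10 label array: 1000 images per class, 10 classes.
y = np.repeat(np.arange(10), 1000)

weights = compute_class_weight(class_weight="balanced",
                               classes=np.arange(10),
                               y=y)
print(weights)  # all 1.0, because the dataset is perfectly balanced
```

So for a balanced dataset, any non-uniform weighting would have to be chosen by hand rather than derived from the class counts.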
And just to make things clear about your problem of the classifier providing neutral results:
consider a batch size of 1.

  1. When the target is the 0th class, the loss will propagate normally.
  2. When the target is the 1st class, the loss will propagate normally.
  3. When the target is between the 2nd and 9th class, how should the loss propagate?

I applied the extreme weights [10, 10, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]; however, the accuracy on classes 0 and 1 is still lower than on the others.