In an image binary classification problem, I've got an unbalanced dataset (10,000 positives vs. 200,000 negatives).
The metric I want to optimize is precision: when an image is predicted as positive, it should mostly be a true positive.
As expected, if I randomly sample a balanced training set (5,000 POS / 5,000 NEG), I get acceptable precision on an equally balanced test set (70%), but very poor precision on a REAL unbalanced test set (24%).
How do I train my model for high precision with PyTorch?
Today I am using this criterion and optimizer:
```python
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
```
Thank you for your reply!
Would you mind clarifying the difference between WeightedRandomSampler and over-/under-sampling during training when it comes to imbalanced inference?
As for the Focal Loss suggestion, I am not sure why it helps mitigate the issue when transferring from train to test, but I am definitely going to try it since I have not yet experimented with it. Thanks!
WeightedRandomSampler (WRS) is used when you don't want to, or can't, under-/over-sample your dataset. This blog explains WRS really well; I'd recommend checking it out.
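If it helps, here's a minimal sketch of how WRS is typically wired up; the `targets` tensor and `train_dataset` below are placeholders standing in for your own data:

```python
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# Placeholder labels matching the original post's counts:
# 200,000 negatives (0) and 10,000 positives (1).
targets = torch.cat([torch.zeros(200_000, dtype=torch.long),
                     torch.ones(10_000, dtype=torch.long)])

# Per-class weight = inverse class frequency, so both classes are
# drawn with roughly equal probability in each batch.
class_counts = torch.bincount(targets)
class_weights = 1.0 / class_counts.float()

# Each sample gets the weight of its class.
sample_weights = class_weights[targets]

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)

# Note: sampler and shuffle are mutually exclusive in DataLoader.
# loader = DataLoader(train_dataset, batch_size=64, sampler=sampler)
```

This resamples on the fly every epoch, so unlike hard over-/under-sampling you never physically duplicate or discard images.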
When you assign class weights to CrossEntropyLoss / BCELoss, the loss for each sample is scaled by the weight of its target class, so mistakes on the minority class become much more expensive. Ultimately, the model will push itself to learn meaningful representations of the minority class to minimize the loss. This works adequately most of the time.
For example, suppose you have two classes, Dog (2,000 images) and Cat (100 images), and you define your loss as (using CrossEntropyLoss here, since BCELoss's weight argument is a per-element rescale rather than a class weight):
```python
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 20.0]))
```
Now the loss on every Cat sample is scaled by 20.0, which matters most when the model misclassifies a Cat as a Dog during training. So to minimize the loss, the model has to learn meaningful representations of cats. Hence, you'll get better accuracy and precision.
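For what it's worth, here's a small sketch of how such weights are often derived from class counts and plugged in; the inverse-frequency heuristic below is one common choice, not the only one, and the counts are just the Dog/Cat numbers from the example:

```python
import torch
import torch.nn as nn

# Class counts from the example above: 2,000 Dog images, 100 Cat images.
class_counts = torch.tensor([2000.0, 100.0])

# Inverse-frequency heuristic, normalized so the majority class gets 1.0.
class_weights = class_counts.max() / class_counts   # tensor([1., 20.])

# For a 2-logit model like the one in the original post:
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Equivalent single-logit setup, where pos_weight scales the minority
# (positive) class:
# criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]))
```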
But still, properly over-sampling your dataset is the best option if you want better overall performance.
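And since focal loss came up earlier in the thread: here's a minimal sketch of a binary focal loss in the spirit of Lin et al. (2017). Treat `alpha` and `gamma` as starting points to tune, not recommendations:

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Binary focal loss on raw logits.

    Down-weights easy examples via (1 - p_t) ** gamma so training
    focuses on hard ones; alpha balances positives vs. negatives.
    """
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

# Hypothetical usage with a single-logit model:
# logits = model(images).squeeze(1)          # shape [N]
# loss = binary_focal_loss(logits, labels.float())
```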