In a binary image classification problem, I have an imbalanced dataset (10,000 positives vs. 200,000 negatives).
The metric I want to optimize is precision: when an image is predicted as positive, it should usually be a true positive.
As expected, if I randomly sample a balanced training set (5,000 POS / 5,000 NEG), I get acceptable precision on an equally split test set (70%) but very poor precision on a real, imbalanced dataset (24%).
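As a rough illustration of why precision drops (an added worked example, not the poster's code): if a classifier flags a fixed fraction of positives (TPR) and of negatives (FPR), its precision depends on how many negatives the test set contains. The TPR/FPR values below are hypothetical, and the 1:20 ratio assumes the real data mirrors the full dataset's 10,000/200,000 split.

```python
def precision(tpr: float, fpr: float, n_pos: int, n_neg: int) -> float:
    """Expected precision of a classifier with fixed TPR/FPR on a test set
    containing n_pos positives and n_neg negatives."""
    true_pos = tpr * n_pos    # positives correctly flagged
    false_pos = fpr * n_neg   # negatives wrongly flagged
    return true_pos / (true_pos + false_pos)


# Hypothetical operating point: 70% of positives and 30% of negatives are flagged.
TPR, FPR = 0.7, 0.3

balanced = precision(TPR, FPR, n_pos=5_000, n_neg=5_000)
imbalanced = precision(TPR, FPR, n_pos=10_000, n_neg=200_000)

print(f"balanced test set:   {balanced:.3f}")    # 0.700
print(f"1:20 imbalanced set: {imbalanced:.3f}")  # 0.104
```

The classifier itself is unchanged between the two calls; only the number of negatives that can turn into false positives grows.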
How do I train my model for high precision with PyTorch?
Currently I am using this loss function and optimizer:
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # net is the model being trained
Thanks all for your help,
Hey! Did you solve the issue after all? I’m struggling with the same thing at the moment.
You have a data imbalance problem. To alleviate it, you can:
- Use a weighted cross-entropy loss, or
- Use WeightedRandomSampler(); see this post for more details.
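A minimal sketch of the first option for the question's setup. It assumes label 0 = negative and label 1 = positive and uses the common inverse-frequency weighting n_samples / (n_classes * count); the random logits are a stand-in for net(images).

```python
import torch
import torch.nn as nn

# Class counts from the question (index 0 = negative, index 1 = positive).
class_counts = torch.tensor([200_000.0, 10_000.0])

# Inverse-frequency weights: n_samples / (n_classes * count_per_class).
# This gives roughly 0.525 for negatives and 10.5 for positives.
weights = class_counts.sum() / (len(class_counts) * class_counts)

criterion = nn.CrossEntropyLoss(weight=weights)

# Dummy batch standing in for net(images): 4 images, 2 logits each.
logits = torch.randn(4, 2)
targets = torch.tensor([0, 0, 0, 1])
loss = criterion(logits, targets)
```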
After that, you can experiment with different optimizers and vary the learning rate during training with learning rate schedulers.
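For the scheduler part, a minimal sketch using torch.optim.lr_scheduler.StepLR together with the SGD settings from the question. The tiny nn.Linear model, the random data, and the step_size/gamma values are placeholders, not recommendations.

```python
import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(10, 2)  # hypothetical stand-in for the image model
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Multiply the learning rate by 0.1 every 5 epochs (illustrative values).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

criterion = nn.CrossEntropyLoss()
inputs, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))

lrs = []
for epoch in range(12):
    optimizer.zero_grad()
    loss = criterion(net(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()  # once per epoch, after optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
```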
Thank you for your reply!
Would you mind clarifying the difference between WeightedRandomSampler and over-/under-sampling the training data, given that inference happens on imbalanced data?
As for the Focal Loss suggestion, I am not sure why it helps mitigate the issue when transferring from train to test, but I am definitely going to try it since I have not yet experimented with it. Thanks!
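WeightedRandomSampler (WRS) is useful when you don’t want to, or can’t, under-/over-sample your dataset yourself. It draws each training sample with a probability proportional to its weight, so with replacement enabled it effectively over-samples the minority class on the fly instead of requiring you to duplicate or drop images. This blog explains WRS really well; I’d recommend checking it out.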
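A minimal WRS sketch on a hypothetical toy dataset (950 negatives, 50 positives, random features). Each sample's weight is the inverse of its class count, so both classes carry the same total probability mass and minority samples are simply drawn more often.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy dataset: 950 negatives (label 0) and 50 positives (label 1).
labels = torch.cat([torch.zeros(950, dtype=torch.long), torch.ones(50, dtype=torch.long)])
features = torch.randn(len(labels), 8)
dataset = TensorDataset(features, labels)

# Per-sample weight = 1 / (size of that sample's class).
class_counts = torch.bincount(labels)  # counts: 950 and 50
sample_weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# One epoch: check what fraction of drawn samples are positives.
torch.manual_seed(0)
drawn = torch.cat([y for _, y in loader])
print(f"positive fraction per epoch: {drawn.float().mean().item():.2f}")  # close to 0.5
```

Note that the sampler replaces shuffle=True in the DataLoader; the two options cannot be combined.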
When you assign class weights to the loss (the weight argument of CrossEntropyLoss, or pos_weight of BCEWithLogitsLoss for a single-output model), each sample’s loss is multiplied by the weight of its target class. Because mistakes on the minority class then cost more, the model is pushed to learn meaningful representations of that class to minimize the loss. In practice this works well enough in most cases.
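A quick check of how the weight argument of nn.CrossEntropyLoss behaves, using random logits and illustrative weights [1.0, 20.0]. It also shows a detail worth knowing: with the default mean reduction, the loss is normalized by the summed weights of the targets rather than by the batch size.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(6, 2)                  # 6 samples, 2 classes
targets = torch.tensor([0, 0, 0, 0, 1, 1])
w = torch.tensor([1.0, 20.0])               # class weights (illustrative values)

plain = nn.CrossEntropyLoss(reduction="none")(logits, targets)
weighted = nn.CrossEntropyLoss(weight=w, reduction="none")(logits, targets)

# Each sample's loss is scaled by the weight of its *target* class,
# whether or not the prediction was correct.
assert torch.allclose(weighted, w[targets] * plain)

# With the default reduction="mean", PyTorch divides by the sum of the
# target weights, not by the batch size.
mean_weighted = nn.CrossEntropyLoss(weight=w)(logits, targets)
assert torch.allclose(mean_weighted, weighted.sum() / w[targets].sum())
```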
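For example, suppose you have two classes: dog (2,000 images) and cat (100 images). If you define your loss as:
nn.CrossEntropyLoss(weight=torch.tensor([1.0, 20.0]))
then every cat image contributes 20 times as much to the training loss as a dog image, so misclassifying a cat as a dog costs 20 times more. To minimize the loss, the model has to learn meaningful representations of cats, which usually improves its performance on the cat class. (For a single-output model trained with BCEWithLogitsLoss, the equivalent is pos_weight=torch.tensor([20.0]) with cat as the positive class.)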
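For a single-logit binary model, the analogous knob is the pos_weight argument of nn.BCEWithLogitsLoss. A small sketch with random scores, assuming cat is encoded as 1 and dog as 0:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(5)                        # one raw score per image (no sigmoid)
targets = torch.tensor([0., 0., 0., 1., 1.])   # 1 = cat (minority), 0 = dog

plain = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)
weighted = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20.0]), reduction="none")(logits, targets)

# Cat samples (target 1) get 20x the loss; dog samples are unchanged.
assert torch.allclose(weighted[targets == 1], 20 * plain[targets == 1])
assert torch.allclose(weighted[targets == 0], plain[targets == 0])
```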
That said, properly over-sampling your dataset is still the best option if you want better overall performance.
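A minimal sketch of plain random over-sampling with torch.utils.data.Subset, on a hypothetical toy dataset of 950 negatives and 50 positives: extra minority indices are drawn with replacement until both classes are the same size.

```python
import torch
from torch.utils.data import Subset, TensorDataset

# Hypothetical toy dataset: 950 negatives (label 0), 50 positives (label 1).
labels = torch.cat([torch.zeros(950, dtype=torch.long), torch.ones(50, dtype=torch.long)])
dataset = TensorDataset(torch.randn(len(labels), 8), labels)

neg_idx = torch.nonzero(labels == 0).flatten()
pos_idx = torch.nonzero(labels == 1).flatten()

# Draw extra minority indices (with replacement) to match the majority count.
torch.manual_seed(0)
extra = pos_idx[torch.randint(len(pos_idx), (len(neg_idx) - len(pos_idx),))]
balanced_idx = torch.cat([neg_idx, pos_idx, extra])
balanced = Subset(dataset, balanced_idx.tolist())

n_pos = sum(int(dataset[i][1]) for i in balanced.indices)
print(len(balanced), n_pos)  # 1900 950
```

Unlike WeightedRandomSampler, the duplicated indices here are fixed once, so every epoch sees exactly the same over-sampled set.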
Hope it helps!