Hello, I’m a total newbie in image classification and I got really confused by all the available techniques for each step of implementing an image classification app, from data preparation to model deployment.
My idea is to learn by doing a real project. I chose to implement a mobile app that classifies an image from the phone camera. For mobile inference I want to test and decide between two pretrained models (transfer learning): MobileNetV2 and YOLO11m-cls.
The problem is that my dataset is small and imbalanced. I have one class with 2567 images, another with 1180, another with 880, another with 192, etc., and the smallest class has 69 images. I’m really stuck figuring out how to approach this case: transfer learning on imbalanced multi-class image data for a mobile app.
Should I use only data augmentation in the data prep stage or combine it with some other method in another stage (training, evaluation)?
Or is data augmentation not suitable in this case, and should I use some other approach instead? I read about stratified splits and stratified k-fold, but I don’t understand whether I should apply only one of them, combine both, or combine one of them with something else. Or should I use a totally different approach?
I just can’t get a clear picture of how to do this and what to do in each stage.
Any advice appreciated, thank you!
Hi!
Most real-world applications have the issue of imbalanced data. There are a few things you can try to increase the performance of the model.
Class weighting
You can apply a weight to each class that is inversely proportional to the number of images of that class in the dataset. It is often better to also normalize these weights. This basically penalizes the model more when it misclassifies an image from a less common class, forcing it to get those classes right as well.
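For example, a minimal sketch in PyTorch. The counts just reuse your numbers from above, and the normalization scheme shown is one common choice, not the only one:

```python
# Minimal sketch of inverse-frequency class weights in PyTorch.
# `class_counts` reuses the per-class image counts from the question.
import torch
import torch.nn as nn

class_counts = torch.tensor([2567, 1180, 880, 192, 69], dtype=torch.float)

# Weight each class inversely proportional to its frequency ...
weights = 1.0 / class_counts
# ... then normalize so the weights sum to the number of classes,
# which keeps the overall loss magnitude comparable to the unweighted case.
weights = weights / weights.sum() * len(class_counts)

# CrossEntropyLoss takes the per-class weights via its `weight` argument.
criterion = nn.CrossEntropyLoss(weight=weights)
```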
Data augmentation
You can try to “generate” more samples by creating augmented versions of your images: rotations, flips, added noise, etc. This emulates having more data, even though that is not actually the case, and more complex models are harder to “fool” this way. The model might also take the easy way out and learn to distinguish augmented from non-augmented images instead of the actual visual appearance of your objects. Nonetheless, it is always good to apply some level of augmentation to your training images for better generalization.
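As a sketch, a typical torchvision training pipeline might look like this; the specific transforms and parameter values are assumptions you should tune to your data:

```python
# Example torchvision augmentation pipeline, applied to the training set only.
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),           # random crop + resize
    transforms.RandomHorizontalFlip(),           # mirror images half the time
    transforms.RandomRotation(15),               # small random rotations
    transforms.ColorJitter(0.2, 0.2, 0.2),       # brightness/contrast/saturation jitter
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],   # ImageNet statistics,
                         std=[0.229, 0.224, 0.225]),   # standard for pretrained models
])

# Validation/test data should only be resized and normalized, never augmented.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```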
Synthetic data
With the current state-of-the-art generative models, it is possible to generate high-quality synthetic images for the classes where your data is limited. You can either generate images from scratch, or use your current images to generate new variants.
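As a rough illustration, assuming the Hugging Face `diffusers` library; the checkpoint, file paths, prompt, and strength value are placeholder assumptions, not a specific recommendation:

```python
# Hypothetical sketch using a diffusers img2img pipeline to create variants
# of existing minority-class images. Always verify that the outputs still
# look like the real class before adding them to the training set.
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float16
).to("cuda")

source = Image.open("minority_class/example.jpg").convert("RGB").resize((512, 512))

# A low strength keeps the output close to the source image, which helps
# preserve the class-defining appearance of the object.
result = pipe(prompt="a photo of <your object>", image=source, strength=0.3).images[0]
result.save("minority_class/synthetic_variant.jpg")
```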
Custom loss
Often some classes are confused by the model more often than others; they may look similar, or you may simply not have enough data for the model to find the distinction. In that case you can also create a custom loss function that penalizes the model more when it makes a classification mistake between those specific classes.
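A hypothetical sketch of such a loss in PyTorch: it scales the standard cross-entropy by an extra penalty whenever the model confuses one of the specified class pairs. The class indices and penalty value below are made-up examples:

```python
# Sketch of a loss that penalizes specific class confusions extra.
# `penalty_matrix[true, pred]` is a multiplier applied to a sample's loss
# when the model predicts `pred` for a sample of class `true`.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConfusionPenaltyLoss(nn.Module):
    def __init__(self, num_classes, penalty_pairs, penalty=2.0):
        super().__init__()
        matrix = torch.ones(num_classes, num_classes)
        for a, b in penalty_pairs:      # penalize both directions of the confusion
            matrix[a, b] = penalty
            matrix[b, a] = penalty
        self.register_buffer("penalty_matrix", matrix)

    def forward(self, logits, targets):
        per_sample = F.cross_entropy(logits, targets, reduction="none")
        preds = logits.argmax(dim=1)
        # Scale each sample's loss by the penalty for its (true, predicted) pair.
        multipliers = self.penalty_matrix[targets, preds]
        return (per_sample * multipliers).mean()

# e.g. classes 3 and 4 are often confused with each other
criterion = ConfusionPenaltyLoss(num_classes=5, penalty_pairs=[(3, 4)])
```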
Stratification can help to properly split the imbalanced dataset into a training and validation set, making sure that you have equal class distributions in either split. You can also take a random subset, which is often sufficient as well, but then you have to check afterwards whether the distributions of the splits are OK (stratification does this for you). Finally, make sure that there is no data leakage between the splits, as this will give skewed performance metrics. This can be an issue with stratification if near-identical images of the same object sit next to each other in the dataset; stratification may then put them in different splits, which gives you data leakage.
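For instance, a minimal sketch of a stratified split with scikit-learn; the 70/15/15 ratio is just an example, and `image_paths`/`labels` are placeholders for your own file paths and class ids:

```python
# Stratified train/val/test split (70/15/15) with scikit-learn.
from sklearn.model_selection import train_test_split

# First split off 70% for training, stratified on the labels.
train_paths, rest_paths, train_labels, rest_labels = train_test_split(
    image_paths, labels, test_size=0.30, stratify=labels, random_state=42
)
# Then split the remaining 30% evenly into validation and test sets.
val_paths, test_paths, val_labels, test_labels = train_test_split(
    rest_paths, rest_labels, test_size=0.50, stratify=rest_labels, random_state=42
)
```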
Thank you! If I use class weighting, is it OK to combine it with a custom loss? Or should I use only one of the proposed techniques?
Also, if I:
1. First calculate the class weights from the complete dataset,
2. Then do a stratified train/val/test split,
3. Then augment only the training set,
4. Then apply the class weights calculated in step 1 to train the model,
is that OK, given that the training data is already augmented?
Here’s how I’d structure your experiment:
- Start with stratified train-validation-test split (70-15-15)
- Apply balanced augmentation to training data only
- Implement class weights in your loss function
- Train both MobileNetV2 and YOLO11m-cls with the same setup
- Use stratified 5-fold CV on your training set for hyperparameter tuning (see the sketch after this list)
- Evaluate on your held-out test set using multiple metrics
- Test inference speed on actual mobile device
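For step 5, a minimal stratified 5-fold CV sketch with scikit-learn, assuming `train_paths`/`train_labels` come from the stratified split in step 1:

```python
# Stratified 5-fold cross-validation for hyperparameter tuning.
import numpy as np
from sklearn.model_selection import StratifiedKFold

train_paths = np.array(train_paths)
train_labels = np.array(train_labels)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
for fold, (fit_idx, val_idx) in enumerate(skf.split(train_paths, train_labels)):
    # Each fold preserves the class proportions of the training set.
    fold_train, fold_val = train_paths[fit_idx], train_paths[val_idx]
    # ... train one candidate configuration here and record its validation score ...
    print(f"fold {fold}: {len(fold_train)} train / {len(fold_val)} val samples")
```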
For your mobile app, also consider:
- Model quantization after training to reduce size (see the sketch after this list)
- Test-time augmentation (TTA) for better predictions at slight computational cost
- Confidence thresholds - maybe refuse to classify if confidence is too low
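On the quantization point, a minimal PyTorch sketch just to show the basic idea. Note that dynamic quantization only covers layers like `nn.Linear`, so the size win on a conv-heavy MobileNetV2 is limited; for a production mobile app, a full export path (e.g. TFLite or ExecuTorch with static quantization) typically shrinks the model further:

```python
# Post-training dynamic quantization in PyTorch.
import torch
from torchvision.models import mobilenet_v2

model = mobilenet_v2(weights="DEFAULT")  # stand-in for your fine-tuned model
model.eval()

# Store weights of nn.Linear layers in int8; activations are quantized on the fly.
quantized = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
torch.save(quantized.state_dict(), "mobilenet_v2_int8.pt")
```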
Remember, with your smallest class having only 69 images, you’re pushing the limits of what’s possible. If performance on the minority classes remains poor, you might need to either collect more data for those classes or merge similar minority classes if that is semantically appropriate.
Class weighting can be applied independently of the augmentations. It is just a matter of how large the penalty is when the model misclassifies that class; it does not matter whether the data has been augmented or not.
As long as the distribution of classes of the full dataset and training dataset is mostly equal, it should be fine to use the same class weights.
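As a quick sanity check, you can compare the class proportions of the full dataset and the training split. This reuses the hypothetical `labels`/`train_labels` names from the split sketch earlier in the thread:

```python
# Verify that the training split keeps the full dataset's class proportions,
# so the class weights precomputed on the full dataset remain valid.
from collections import Counter

def class_fractions(labels):
    counts = Counter(labels)
    total = sum(counts.values())
    return {c: round(n / total, 3) for c, n in sorted(counts.items())}

print("full :", class_fractions(labels))
print("train:", class_fractions(train_labels))
```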
@Hamza_Javaid also added some nice steps that complement mine!
However, I would start simple, with for instance only class weights, and work your way up by adding more complexity over time, like a custom loss function, to see if this improves your performance. You also don’t need a custom loss yet if you just use class weights; this is already implemented in the standard CrossEntropyLoss of PyTorch (see CrossEntropyLoss — PyTorch 2.7 documentation).
Thank you very much for the clear and structured advice. That makes things a lot clearer and answers some of the questions I had. For example, one thing I wondered was whether, after a stratified dataset split, I can also apply stratified k-fold cross-validation (points 1 and 5 of your list confirm that combination).
Thank you again!
Thank you for your clear and descriptive advice! I can see light at the end of the tunnel now.