I am using a pre-trained MobileNetV2 model on a custom dataset. I changed the number of outputs from 1000 to 3, since that is the number of classes I am working with. Class A has 8000 images, class B has 5000 images, and class C has 500 images. I took 100 images of each class out for the validation set and used the rest as the training set. I ran the experiment for 90 epochs, multiplying the learning rate by 0.1 every 30 epochs, with these hyperparameters:
learning rate = 0.001
weight decay = 4e-5
momentum = 0.9
batch size = 64
optimizer = SGD
dropout = 0.2
After completing the experiment, I noticed that classes A and B reach accuracies between 80% and 97% in the last few runs, while class C stays in the 70% range.
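Per-class accuracies like these can be read off a confusion matrix. A minimal sketch, assuming the predicted and true labels for the validation set have been collected into two 1-D integer tensors; `per_class_accuracy` is a hypothetical helper:

```python
import torch


def per_class_accuracy(preds: torch.Tensor, targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Fraction of each true class that was predicted correctly."""
    # confusion[i, j] = number of samples of true class i predicted as class j
    flat = targets * num_classes + preds
    confusion = torch.bincount(flat, minlength=num_classes * num_classes)
    confusion = confusion.reshape(num_classes, num_classes)
    correct = confusion.diag().float()
    totals = confusion.sum(dim=1).float()
    # clamp avoids division by zero for a class absent from `targets`
    return correct / totals.clamp(min=1)
```

During validation, `preds` would come from `model(images).argmax(dim=1)` concatenated over the validation loader. Because the validation set holds exactly 100 images per class, the overall accuracy equals the mean of the three per-class accuracies, so any gain on class C moves the overall number by a third of that gain.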
I understand that class C lags because its sample is so small compared to A and B, and I am looking for ways to work around that, since even if I add more data, class C would still be less than half the size of the other two classes. This is my data augmentation setup:
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),  # flip 50% of images horizontally
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
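The hold-out split described at the top (100 images per class for validation, the rest for training) can be built so that each subset gets its own transform. A minimal sketch, assuming the labels are available as a list such as `ImageFolder.targets`; `per_class_holdout` is a hypothetical helper name:

```python
import random
from collections import defaultdict


def per_class_holdout(targets, n_val_per_class=100, seed=0):
    """Return (train_indices, val_indices) with n_val_per_class samples of every class in val."""
    by_class = defaultdict(list)
    for idx, label in enumerate(targets):
        by_class[label].append(idx)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    train_idx, val_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label][:]
        rng.shuffle(indices)
        val_idx.extend(indices[:n_val_per_class])
        train_idx.extend(indices[n_val_per_class:])
    return train_idx, val_idx
```

To use it, one option is to create two `torchvision.datasets.ImageFolder` instances over the same folder, one with `transform=train_transform` and one with `transform=test_transform`, and wrap them as `Subset(train_ds, train_idx)` and `Subset(val_ds, val_idx)`. `ImageFolder` lists files in sorted order, so the indices line up across both instances, and the random crops and flips never touch the validation images.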
Cross-validation is the only technique I found while researching this problem that I am not using, and implementing cross-validation in PyTorch is proving rather tricky for me. One technique I read about in a paper is where the author claimed to have augmented each class to 6,000 images, but they didn't say how they did this. Are there any ways I can increase the accuracy on class C with the limited samples available, and by extension increase the overall accuracy of my model from 84% to 95+%?
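For the cross-validation part, one common approach is to generate fold indices with scikit-learn's `StratifiedKFold`, which keeps the A/B/C proportions the same in every fold, and feed them to `torch.utils.data.Subset`. A sketch, assuming scikit-learn is installed and the labels are in a list such as `ImageFolder.targets`; `stratified_folds` is a hypothetical helper:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold


def stratified_folds(targets, n_splits=5, seed=0):
    """Return a list of (train_indices, val_indices) pairs, one per fold."""
    y = np.asarray(targets)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # split() only uses X for its length, so a placeholder array is enough
    return [(tr.tolist(), va.tolist()) for tr, va in skf.split(np.zeros(len(y)), y)]
```

Each `(train_idx, val_idx)` pair would drive one full training run: re-create the model from the pretrained weights, train on `Subset(train_ds, train_idx)`, evaluate on `Subset(val_ds, val_idx)`, and average the per-class accuracies across folds. Cross-validation mainly gives a more reliable estimate of accuracy; it does not raise accuracy by itself.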
Update
I came across the concept of class weights in the cross-entropy loss. The formula I found for each class's weight is 1/class_size, so in my case the weight for class A would be 1/8000 = 1.25 × 10^-4. If this is not the correct way to determine it, please let me know.
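The 1/class_size formula is a common inverse-frequency weighting, and the arithmetic checks out: 1/8000 = 1.25 × 10^-4, 1/5000 = 2 × 10^-4 and 1/500 = 2 × 10^-3, so a class C sample counts 16 times as much as a class A sample. A minimal sketch of passing such weights to `nn.CrossEntropyLoss` through its `weight` argument; `inverse_frequency_weights` is a hypothetical helper, and the counts are the full class sizes from the question (the training-set counts after the 100-image hold-out would be 7,900, 4,900 and 400):

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(class_counts):
    """Weight each class by 1 / class_size, rescaled so the weights average to 1."""
    counts = torch.as_tensor(class_counts, dtype=torch.float32)
    weights = 1.0 / counts
    # The rescaling is cosmetic for reduction="mean": PyTorch divides the weighted
    # sum by the summed weights of the batch targets, so a global factor cancels out.
    return weights * len(counts) / weights.sum()


# Class sizes for A, B, C as given in the question
weights = inverse_frequency_weights([8000, 5000, 500])
criterion = nn.CrossEntropyLoss(weight=weights)
```

scikit-learn's `compute_class_weight("balanced", ...)` uses n_samples / (n_classes × count), which differs from 1/class_size only by a constant factor, so with the default mean reduction it yields the same loss. When training on a GPU, the weight tensor must be on the same device as the logits.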