Weighted sampling & Weighted CE loss not helping

You are right! A BatchNorm layer working directly on your input should achieve a similar effect.
In any case, it shouldn't hurt the performance, so I would go for it.

I wasn’t sure whether BatchNorm achieves the same effect, so I created a small sandbox example to play around with.
Not related to your topic, but maybe it’s interesting nevertheless: gist.
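
In case a sketch helps, this is roughly what I mean by a BatchNorm layer working directly on the input. The layer sizes, feature dimension, and class count are just placeholders:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, num_features=10, num_classes=3):
        super().__init__()
        # BatchNorm applied directly to the raw input features,
        # so the rest of the network sees (roughly) normalized inputs
        self.input_bn = nn.BatchNorm1d(num_features)
        self.fc1 = nn.Linear(num_features, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.input_bn(x)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = Net()
out = model(torch.randn(8, 10))  # batch of 8 samples with 10 features
```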

Weighted sampling does not necessarily improve overall accuracy. It is not a magic tool that can solve this problem.

Weighted sampling behaves the same as overweighting observations in the minority classes and downweighting observations in the majority classes. Therefore, it improves the accuracy in the minority classes AT THE COST OF accuracy in the majority classes. Depending on the overall accuracy metric you use, there is no guarantee that the optimum of the weighted loss function coincides with the optimum of that accuracy metric.

The best practice is, rather than using a single accuracy metric across all classes, to look at the confusion matrices with and without the weighting scheme and compare the accuracy of each class. That way you can see how the accuracies of the different classes are traded off, and whether the trade-off is reasonable for your purpose.
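
To make the equivalence and the comparison concrete, here is a rough sketch. The class counts, tensor names, and inverse-frequency weights are just illustrative choices, not a prescription:

```python
import torch
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler

# Hypothetical imbalanced dataset: one label per sample
targets = torch.tensor([0] * 900 + [1] * 90 + [2] * 10)
class_counts = torch.bincount(targets).float()

# Option A: re-weight the loss (inverse-frequency weights as one common choice)
class_weights = class_counts.sum() / class_counts
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Option B: re-weight the sampling so minority classes are drawn more often
sample_weights = class_weights[targets]
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(targets),
                                replacement=True)

# Per-class accuracy read off a confusion matrix
# (rows = true class, columns = predicted class)
def per_class_accuracy(conf_mat):
    return conf_mat.diag().float() / conf_mat.sum(dim=1).float().clamp(min=1)
```

Comparing `per_class_accuracy` with and without the weighting shows exactly where the trade-off between majority and minority classes happens.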

A follow-up question: when should I stop training? Clearly, mean per-class accuracy doesn’t necessarily increase when overall accuracy increases. I get a higher mean per-class accuracy at epochs where the overall validation accuracy isn’t the highest.

You could track the mean per-class accuracy or another metric on your validation set while training and save the best checkpoint of your model.
You would then tweak the hyperparameters of your model and choose the configuration with the highest metric.
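
A minimal sketch of that idea; `train_one_epoch`, `evaluate`, the loaders, `num_epochs`, and the checkpoint path are placeholders you would replace with your own code:

```python
import torch

best_metric = 0.0
for epoch in range(num_epochs):
    train_one_epoch(model, train_loader)      # your training step
    conf_mat = evaluate(model, val_loader)    # returns a confusion matrix

    per_class_acc = conf_mat.diag().float() / conf_mat.sum(dim=1).float().clamp(min=1)
    mean_per_class_acc = per_class_acc.mean().item()

    # keep the checkpoint with the best mean per-class accuracy seen so far
    if mean_per_class_acc > best_metric:
        best_metric = mean_per_class_acc
        torch.save(model.state_dict(), 'best_model.pth')
```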

I meant to ask: when should I terminate training? Should I terminate when the overall training accuracy reaches nearly 100%?

I would terminate it when your model doesn’t improve anymore or starts overfitting, i.e. when the mean per-class accuracy reaches a plateau or gets worse.

But the mean per-class accuracy doesn’t increase as smoothly as the overall accuracy. It fluctuates a lot after the first few epochs of training, and then suddenly, out of nowhere, I get a high mean per-class accuracy.

I’m not sure if there is another valid approach to this issue.
I would treat the mean per-class accuracy like any other metric you want to optimize, i.e. observe it and stop the training when your model seems good enough.

Depending on your dataset, you could also observe several other metrics such as sensitivity, specificity, etc.
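
For example, per-class sensitivity and specificity can be derived from the same confusion matrix. A sketch, assuming rows are the true classes and columns the predictions:

```python
import torch

def sensitivity_specificity(conf_mat):
    # rows = true class, columns = predicted class
    tp = conf_mat.diag().float()
    fn = conf_mat.sum(dim=1).float() - tp
    fp = conf_mat.sum(dim=0).float() - tp
    tn = conf_mat.sum().float() - tp - fn - fp

    sensitivity = tp / (tp + fn).clamp(min=1)  # recall per class
    specificity = tn / (tn + fp).clamp(min=1)
    return sensitivity, specificity
```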
