Local optima on imbalanced data

I have an image dataset with 50 classes. It is highly imbalanced: about 75% of the images belong to the three most common classes (class 12: 35%, class 32: 25%, class 41: 15%).
When I train one of the torchvision architectures on this dataset without class weights or balanced sampling, the model quickly learns that always predicting the common classes in order (12, 32, 41,…) gives a low NLL loss (2.xx) and ~35% top-1 accuracy. I believe this is normal behavior. So far so good.
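The 2.xx loss and ~35% accuracy are roughly what a model that ignores the image and always outputs the class frequencies would get. A minimal sketch to check this, assuming (hypothetically, since the post does not say) that the remaining 25% of images is spread evenly over the other 47 classes:

```python
import torch
import torch.nn.functional as F

# Class frequencies from the post: class 12 -> 35%, class 32 -> 25%, class 41 -> 15%.
# ASSUMPTION (hypothetical): the remaining 25% is spread evenly over the other 47 classes.
NUM_CLASSES = 50
prior = torch.full((NUM_CLASSES,), 0.25 / 47)
prior[12] = 0.35
prior[32] = 0.25
prior[41] = 0.15

# Expected NLL of a model that ignores the image and always outputs the prior:
# the entropy of the label distribution, sum_c p_c * (-log p_c), in nats.
expected_nll = -(prior * prior.log()).sum().item()

# Top-1 accuracy of that model: it always picks the most frequent class (12).
top1_acc = prior.max().item()

# Empirical check with F.cross_entropy on labels sampled from the prior.
# Using log(prior) as logits makes softmax(logits) equal to the prior.
torch.manual_seed(0)
labels = torch.multinomial(prior, 100_000, replacement=True)
logits = prior.log().expand(len(labels), -1)
empirical_nll = F.cross_entropy(logits, labels).item()

print(f"expected NLL {expected_nll:.3f}, sampled NLL {empirical_nll:.3f}, top-1 {top1_acc:.2f}")
```

Under that assumption the expected NLL comes out at about 2.31 nats with 35% top-1 accuracy, consistent with the numbers above; the exact value depends on how the remaining 25% is really distributed.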

I see exactly the same behavior in TensorFlow (same dataset, same architecture, same optimizer, same hyperparameters,…): TensorFlow also finds the trivial solution quickly.

Here is my problem:

If I keep training, TensorFlow eventually escapes this local optimum and actually learns to classify the images correctly. PyTorch also escapes it after a while (the loss drops to 1.xx), but only for a few iterations: it quickly gets trapped in the local optimum again and never escapes!

I have tried different architectures, optimizers, and learning rates, but this happens every time. What is wrong with PyTorch, or what is magical about TensorFlow?

I checked my PyTorch code with a balanced dataset, and it works fine.
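For reference, the class weights and balanced sampling that the original runs deliberately left out can be added in PyTorch roughly as follows. This is a sketch on synthetic labels (the `probs`, `labels` and `features` tensors are hypothetical stand-ins for the real dataset), using the standard `nn.CrossEntropyLoss(weight=...)` and `WeightedRandomSampler`:

```python
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical labels standing in for the real dataset: 50 classes, heavily
# imbalanced toward classes 12, 32 and 41 (same shape as described in the post).
torch.manual_seed(0)
NUM_CLASSES = 50
probs = torch.full((NUM_CLASSES,), 0.25 / 47)
probs[12] = 0.35
probs[32] = 0.25
probs[41] = 0.15
labels = torch.multinomial(probs, 20_000, replacement=True)
features = torch.randn(len(labels), 8)  # placeholder inputs
dataset = TensorDataset(features, labels)

# clamp(min=1) avoids division by zero for any class absent from the data.
counts = torch.bincount(labels, minlength=NUM_CLASSES).float().clamp(min=1)

# Option 1: class-weighted loss with inverse-frequency weights, rescaled to mean 1.
class_weights = 1.0 / counts
class_weights = class_weights / class_weights.mean()
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Option 2: balanced sampling. Each sample is drawn with probability proportional
# to 1 / (size of its class), so every class is roughly equally likely per batch.
sample_weights = (1.0 / counts)[labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
loader = DataLoader(dataset, batch_size=256, sampler=sampler)

# Class histogram of one epoch drawn through the sampler.
seen = Counter()
for _, y in loader:
    seen.update(y.tolist())
```

Usually only one of the two options is used at a time; combining them over-corrects toward the rare classes.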

Does anyone have any idea? Any help is appreciated.


This might have something to do with how the default initializations are implemented in PyTorch and TensorFlow. To check, you could train a small MLP in both frameworks and override all initializations with, say, Xavier, or even a constant such as 0.5. I would love to dig further if the results still differ after that.
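On the PyTorch side, that experiment might look like the sketch below (the helper names `make_mlp`, `init_xavier` and `init_constant` are hypothetical). Keras `Dense` layers default to Glorot (Xavier) uniform weights and zero biases, while PyTorch's `nn.Linear` uses a Kaiming-uniform-based default, so the Xavier variant roughly mimics the Keras default:

```python
import torch
import torch.nn as nn


def make_mlp(in_dim: int, hidden: int, out_dim: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Linear(in_dim, hidden),
        nn.ReLU(),
        nn.Linear(hidden, out_dim),
    )


def init_xavier(module: nn.Module) -> None:
    # Replace PyTorch's default nn.Linear init with Xavier/Glorot uniform
    # weights and zero biases.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        nn.init.zeros_(module.bias)


def init_constant(module: nn.Module, value: float = 0.5) -> None:
    # Degenerate but fully deterministic init: every weight and bias equals `value`.
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, value)
        nn.init.constant_(module.bias, value)


torch.manual_seed(0)
model_xavier = make_mlp(16, 32, 50)
model_xavier.apply(init_xavier)

model_const = make_mlp(16, 32, 50)
model_const.apply(lambda m: init_constant(m, 0.5))

# With constant init the forward pass is identical in any framework:
# hidden = 16 * 0.5 * 1 + 0.5 = 8.5, output = 32 * 0.5 * 8.5 + 0.5 = 136.5.
x = torch.ones(1, 16)
out_const = model_const(x)
```

The constant init is useless for learning (all hidden units stay identical), but it makes the forward pass bit-for-bit comparable between PyTorch and TensorFlow, which helps isolate whether the divergence comes from initialization or from elsewhere.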

It was a bug in my preprocessing step!
I was normalizing the images incorrectly. Now PyTorch works fine.
Thanks for your help.
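The thread does not say exactly what the normalization bug was, so the following is only a generic sketch of consistent per-channel normalization, not the poster's fix. `channel_stats` and `normalize` are hypothetical helpers, and a random `uint8` tensor stands in for the training images; `normalize` performs the same per-channel operation as `torchvision.transforms.Normalize(mean, std)`:

```python
import torch


def channel_stats(images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-channel mean and std over a batch shaped (N, C, H, W)."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std


def normalize(images: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # (x - mean[c]) / std[c], broadcast over H and W.
    return (images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)


# Hypothetical uint8 images standing in for the real training set.
torch.manual_seed(0)
raw = torch.randint(0, 256, (64, 3, 32, 32), dtype=torch.uint8)

# Scale to [0, 1] before computing statistics: stats computed on 0-255 values
# do not match inputs that are later scaled to [0, 1] (e.g. by ToTensor).
images = raw.float() / 255.0
mean, std = channel_stats(images)
normalized = normalize(images, mean, std)
```

The statistics are computed once on the training split and reused unchanged for validation and test images; recomputing them per batch or per split is another easy way to get mismatched inputs.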