# Choosing weights for Weighted Cross Entropy Loss

I am trying to train a U-Net for image segmentation. My dataset consists of 80x80 pixel images (so 6400 pixels per image), and each image can be segmented into 3 parts: primary background, secondary background, and a third class that can be any one of 9 separately defined classes. So I have 11 classes in total.

My goal is to use image segmentation to determine what that third class is for each image. I decided to sample roughly 1000 images from the dataset and this is what I observed: on average ~5587 pixels belonged to the primary background class, ~802 pixels belonged to the secondary background class, and ~11 pixels belonged to the third class (which again, can be any one of the nine special classes).
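One way to get per-class averages like these is to histogram all mask pixels with `torch.bincount` and divide by the number of images. This assumes the ground-truth masks are integer tensors of class indices in 0-10. The function name and that index layout are assumptions for illustration, not taken from the post:

```python
import torch

NUM_CLASSES = 11  # 2 background classes + 9 "special" classes

def average_class_counts(masks: torch.Tensor, num_classes: int = NUM_CLASSES) -> torch.Tensor:
    """Average number of pixels of each class per image.

    Assumes `masks` is an integer tensor of shape (num_images, H, W)
    whose entries are class indices in [0, num_classes).
    """
    # Count every pixel of every mask, grouped by class index.
    counts = torch.bincount(masks.flatten(), minlength=num_classes)
    # Divide by the number of images to get a per-image average.
    return counts.float() / masks.shape[0]

# Toy example with two 2x2 masks (made-up data, not the poster's dataset).
masks = torch.tensor([[[0, 0], [1, 2]],
                      [[0, 0], [0, 1]]])
print(average_class_counts(masks, num_classes=3))  # tensor([2.5000, 1.0000, 0.5000])
```

If the masks are instead stored one-hot with shape (num_images, num_classes, H, W), `masks.sum(dim=(0, 2, 3)) / masks.shape[0]` gives the same averages by adding up the 1s in each class channel.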

Obviously, this is a very unbalanced dataset, so I was thinking about using weighted cross-entropy loss as my criterion. How should I go about choosing these weights? Any input will be appreciated.

Hi Entropy!

The conventional wisdom suggests that you should weight each
class inversely to the number of samples (pixels) in that class.
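Written out, let $N$ be the number of pixels per image (here $N = 80 \times 80 = 6400$) and $n_c$ the average number of pixels per image belonging to class $c$. The rule then gives

$$w_c = \frac{N}{n_c}.$$

Any common factor could stand in for $N$; using $N$ simply means a class that covered the whole image would get weight 1.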

Now your "third class" is actually nine "special classes." Let's
assume (since you don't give the numbers) that those "special
classes" occur about equally often, so that each one accounts for,
on average, 11 / 9 pixels per image.
Then I would weight your eleven actual classes as follows:

primary background: 6400 / 5587
secondary background: 6400 / 802
third class - 1: 9 * 6400 / 11
third class - 2: 9 * 6400 / 11
…
third class - 9: 9 * 6400 / 11

(If they don't occur equally often, you would not want to weight
them all the same, but rather in inverse proportion to how frequently
they occur.)
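Here is a minimal PyTorch sketch of these weights passed to `nn.CrossEntropyLoss`. It assumes class indices 0 and 1 are the primary and secondary backgrounds and 2-10 are the nine special classes; that ordering is an assumption, not something stated in the thread. It also assumes the ~11 special-class pixels split evenly over the nine classes:

```python
import torch
import torch.nn as nn

PIXELS_PER_IMAGE = 80 * 80  # 6400

# Average pixel counts per image from the ~1000-image sample in the question.
# Assumption: the 11 "third class" pixels are split evenly over the 9 special classes.
avg_counts = torch.tensor([5587.0, 802.0] + [11.0 / 9] * 9)

# Inverse-frequency weights: w_c = N / n_c
class_weights = PIXELS_PER_IMAGE / avg_counts

criterion = nn.CrossEntropyLoss(weight=class_weights)

# Segmentation shapes: logits (batch, 11, H, W), targets (batch, H, W) of class indices.
logits = torch.randn(4, 11, 80, 80)
targets = torch.randint(0, 11, (4, 80, 80))
loss = criterion(logits, targets)
print(class_weights)
print(loss.item())
```

On a GPU the weight tensor has to be on the same device as the logits, e.g. `nn.CrossEntropyLoss(weight=class_weights.to(device))`.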

Lastly, the exact values of these weights shouldn't really matter. You
just need them to be approximately right to smooth out the significant
imbalance among classes in your dataset.
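One concrete reason the overall size of the weights doesn't matter: with PyTorch's default `reduction="mean"`, `nn.CrossEntropyLoss` divides the weighted sum of per-pixel losses by the sum of the target pixels' weights. Multiplying every weight by the same constant therefore leaves the loss unchanged, and only the ratios between classes count. The small check below uses made-up weights and random data purely for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)          # (batch, classes, H, W)
targets = torch.randint(0, 3, (2, 4, 4))  # class index per pixel

w = torch.tensor([1.0, 8.0, 500.0])       # illustrative weights, not tuned
loss_a = nn.CrossEntropyLoss(weight=w)(logits, targets)
loss_b = nn.CrossEntropyLoss(weight=10.0 * w)(logits, targets)  # same ratios, 10x scale

# With reduction="mean", the common factor of 10 cancels between numerator and denominator.
print(torch.allclose(loss_a, loss_b))  # True
```

This only covers the overall scale. With `reduction="sum"` the scale does change the loss and its gradients, much like changing the learning rate for plain SGD.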

Best.

K. Frank


Thank you for replying! They do occur equally often, so that's nice! Since the values don't need to be exact, only somewhere in the right ballpark, is there an estimate of how far off they can be while still giving a decent loss function?

Hi @Entropy,

Could you kindly shed more light on how you got the average count for each class? Did you sum up the 1's at each pixel and take the average for a particular class?