# Loss weighting for imbalanced classes

My aim is to predict a star rating from 1 to 5 from a Yelp review. I’m training a network on a Yelp dataset that is severely skewed toward 4-star ratings, and I’d like to weight training so that samples labeled 4 play a smaller role.

I’m using `nn.CrossEntropyLoss()`, and I understand that I can give it a per-class weighting vector so that the loss computed from my scores and training labels is weighted by class. My questions are:

1. Should my weight vector values be calculated from the total distribution, or from the distribution of the current batch at the specific iteration where I call the loss function?

2. How should the weighting vector look? If I have 100 examples distributed as 1: 10, 2: 10, 3: 10, 4: 50, 5: 20, would I want [1/10, 1/10, 1/10, 1/50, 1/20] or [10, 10, 10, 50, 20], and why?

> Should my weight vector values be calculated from the total distribution, or from the distribution of the current batch at the specific iteration where I call the loss function?

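You should calculate the weights once, from the class distribution of your whole training set. Per-batch counts are noisy, and a class that happens to be missing from a batch would have a count of zero there, so inverse-count weights would not even be defined for it.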
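A minimal sketch of computing the weights once from the full training set, assuming the star ratings have already been mapped to class indices 0 to 4 (`nn.CrossEntropyLoss` expects class indices starting at 0). The helper `class_weights_from_labels` and the label tensors are hypothetical; the weights are simply 1/count, matching the first option in the question.

```python
import torch
import torch.nn as nn


def class_weights_from_labels(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: inverse-count weights from the FULL training label tensor."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    # A class with no training examples would get an infinite weight, so fail loudly.
    if (counts == 0).any():
        raise ValueError("every class needs at least one training example")
    return 1.0 / counts


# Hypothetical training labels with the distribution from the question
# (stars 1..5 mapped to indices 0..4): 10, 10, 10, 50, 20 examples.
train_labels = torch.tensor([0] * 10 + [1] * 10 + [2] * 10 + [3] * 50 + [4] * 20)

weights = class_weights_from_labels(train_labels, num_classes=5)
criterion = nn.CrossEntropyLoss(weight=weights)  # build once, reuse for every batch

# Inside the training loop, every batch uses the same fixed weights.
scores = torch.randn(8, 5)                # stand-in for the network's logits
batch_labels = torch.randint(0, 5, (8,))  # stand-in for one batch of labels
loss = criterion(scores, batch_labels)
```

Because the criterion is built once outside the loop, the weights stay fixed for the whole run instead of fluctuating with each batch's composition.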

> How should the weighting vector look? If I have 100 examples distributed as 1: 10, 2: 10, 3: 10, 4: 50, 5: 20, would I want [1/10, 1/10, 1/10, 1/50, 1/20] or [10, 10, 10, 50, 20], and why?

You should weight in inverse proportion to the class counts: classes with more examples get a smaller weight, because you want the loss to go down for the rarer classes as well, not just for the common one. So of your two options, [1/10, 1/10, 1/10, 1/50, 1/20] points in the right direction. For example, with classes A and B having 90 and 10 samples respectively, the weights I would use are 0.1 for A and 0.9 for B.
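The 0.1/0.9 weights for A and B are what you get by taking the inverse counts and normalizing them to sum to 1. A small sketch of that calculation; the helper name `normalized_inverse_frequency` is hypothetical.

```python
import torch


def normalized_inverse_frequency(counts: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: weights proportional to 1/count, rescaled to sum to 1."""
    inv = 1.0 / counts.float()
    return inv / inv.sum()


# Two-class example from the answer: A has 90 samples, B has 10.
two_class = normalized_inverse_frequency(torch.tensor([90, 10]))

# Five-class example from the question: stars 1..5 with 10, 10, 10, 50, 20 samples.
five_class = normalized_inverse_frequency(torch.tensor([10, 10, 10, 50, 20]))

print(two_class)
print(five_class)
```

For A and B this gives 0.1 and 0.9 (up to floating-point rounding). For the five-class example, the 4-star class gets the smallest weight and the 1-, 2- and 3-star classes share the largest.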

Right now I am handling the class imbalance by passing the following weight vector to CrossEntropyLoss:

`[1/(# instances in class 0), 1/(# instances in class 1), …, 1/(# instances in class n)]`
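One property of this vector worth checking against the worry about slower learning: with the default `reduction="mean"`, PyTorch's weighted cross-entropy divides by the sum of the weights of the targets in the batch, so multiplying every weight by the same constant leaves the loss (and therefore its gradients) unchanged. The sketch below, using made-up logits and labels, compares the raw 1/(# instances) vector with a copy normalized to sum to 1; under these assumptions only the relative weights matter, and the overall scale only shows up with `reduction="sum"`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
scores = torch.randn(16, 5)          # made-up logits for a batch of 16 reviews
labels = torch.randint(0, 5, (16,))  # made-up class indices 0..4

counts = torch.tensor([10.0, 10.0, 10.0, 50.0, 20.0])
raw = 1.0 / counts          # the 1/(# instances) vector
normalized = raw / raw.sum()  # same relative weights, rescaled to sum to 1

# reduction="mean": sum(w[y] * loss) / sum(w[y]), so a common scale factor cancels.
loss_raw = nn.CrossEntropyLoss(weight=raw)(scores, labels)
loss_norm = nn.CrossEntropyLoss(weight=normalized)(scores, labels)
print(loss_raw.item(), loss_norm.item())

# reduction="sum": no normalization, so the loss scales with the weights.
sum_raw = nn.CrossEntropyLoss(weight=raw, reduction="sum")(scores, labels)
sum_norm = nn.CrossEntropyLoss(weight=normalized, reduction="sum")(scores, labels)
print(sum_raw.item(), sum_norm.item())
```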

Do you see drawbacks to this approach other than slower learning? The other approach I’m considering is:

`p_i = (# instances in class i) / (# total samples in set)`

`p_i' = 1 - p_i`

Then the vector would be `[p_0', p_1', …, p_n']`.
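To compare the two schemes, here is a small sketch on the example distribution from the first question (10, 10, 10, 50, 20 examples). The helper `rare_to_common_ratio` is hypothetical and only measures how much more a 1-star sample counts than a 4-star sample under a given weight vector.

```python
import torch

counts = torch.tensor([10.0, 10.0, 10.0, 50.0, 20.0])  # example distribution from the question

p = counts / counts.sum()  # p_i = (# instances in class i) / (# total samples)
one_minus_p = 1.0 - p      # p_i' = 1 - p_i
inverse = 1.0 / counts     # 1 / (# instances in class i)


def rare_to_common_ratio(w: torch.Tensor) -> float:
    """Hypothetical helper: weight of class index 0 (10 examples) over class index 3 (50 examples)."""
    return (w[0] / w[3]).item()


print(one_minus_p)
print(rare_to_common_ratio(one_minus_p))
print(rare_to_common_ratio(inverse))
```

With these counts, 1 - p_i gives weights of 0.9, 0.9, 0.9, 0.5 and 0.8, so a 1-star sample counts 1.8 times as much as a 4-star sample, whereas 1/(# instances) makes it count 5 times as much. The 1 - p_i scheme is therefore a much milder correction, and with many classes that each hold a small share of the data, every p_i is small and all the weights end up close to 1.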