pos_weight for a multiclass classification task

Hi. I am using a CNN to grade breast cancer, so this problem falls into the category of multiclass image classification. The classes in my dataset are highly imbalanced because high-grade cancer cases are scarce.

The model's output is the logits from its last Linear layer, defined as:
self.classifiers = nn.Linear(size[1], n_classes)
and I compute the logits with:
logits = self.classifiers(M)

Then, in the training loop, I pass the logits to the loss function:
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(logits, label)

How can I add some sort of class weighting here, to penalize classes that have low representation in the dataset? I have 6 classes in total, [0, 1, 2, 3, 4, 5], and the underrepresented ones are 3, 4, and 5.

Hi Anita!

CrossEntropyLoss’s weight constructor argument reweights your classes
in the loss function. (CrossEntropyLoss’s weight is quite analogous to
BCEWithLogitsLoss’s pos_weight constructor argument.)

I’m not sure I understand what you mean by “penalize classes that have low representation.” Typically you would want to weight more heavily the classes
that occur less frequently in your dataset, and the reciprocal of the frequency
with which each class occurs is a common choice for the class weight.
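As a minimal sketch, assuming your training labels are available as one integer tensor, the reciprocal-frequency weights can be computed directly from the labels with torch.bincount. The helper name inverse_frequency_weights and the label counts below are hypothetical; the counts are chosen so their proportions match the ten-times/twenty-times example that follows.

```python
import torch

def inverse_frequency_weights(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    """Hypothetical helper: one weight per class, equal to 1 / (class count)."""
    # Number of samples in each class; minlength keeps absent classes at 0.
    counts = torch.bincount(labels, minlength=n_classes).float()
    # Avoid dividing by zero if a class has no samples at all.
    counts = counts.clamp(min=1.0)
    return 1.0 / counts

# Made-up labels: classes 0-2 are common, classes 3-5 are rare.
train_labels = torch.tensor([0] * 200 + [1] * 200 + [2] * 200
                            + [3] * 20 + [4] * 20 + [5] * 10)
class_weights = inverse_frequency_weights(train_labels, n_classes=6)
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
print(class_weights)
```

With the default reduction='mean', CrossEntropyLoss divides by the summed weights of the targets in the batch, so only the relative sizes of the weights matter; multiplying them all by a constant leaves the loss unchanged.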

Thus, if classes 0, 1, and 2 occur ten times as often as classes 3 and 4,
and twenty times as often as class 5, you could use:

import torch
class_weights = 1.0 / torch.tensor ([20.0, 20.0, 20.0, 2.0, 2.0, 1.0])
loss_fn = torch.nn.CrossEntropyLoss (weight = class_weights)

as your weighted loss criterion.
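To see what the weights actually do, here is a small sketch with a made-up batch (the random logits and labels are purely illustrative). It checks the weighted loss against a hand computation: each sample's loss is scaled by the weight of its true class, and the default mean reduction divides by the sum of those weights rather than by the batch size.

```python
import torch
import torch.nn.functional as F

# Reciprocal-frequency weights from the example above.
class_weights = 1.0 / torch.tensor([20.0, 20.0, 20.0, 2.0, 2.0, 1.0])
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)

torch.manual_seed(0)
logits = torch.randn(8, 6)         # hypothetical batch: 8 samples, 6 classes
label = torch.randint(0, 6, (8,))  # integer class indices in [0, 5]

loss = loss_fn(logits, label)

# Hand computation: unweighted per-sample loss, scaled by the weight of each
# sample's true class, then divided by the sum of those weights.
per_sample = F.cross_entropy(logits, label, reduction="none")
w = class_weights[label]
manual = (w * per_sample).sum() / w.sum()
print(loss.item(), manual.item())
```

If you train on a GPU, the weight tensor has to be on the same device as the logits, for example torch.nn.CrossEntropyLoss(weight=class_weights.to(device)), where device is whatever device your model uses.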

(Alternatively, you could sample your less-common classes more frequently,
for example, by using WeightedRandomSampler. I tend to think that there
are some theoretical reasons to prefer weighted sampling over class weights
in the loss function for the use case you describe, but both approaches are
sensible and both are commonly used.)
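A minimal sketch of the weighted-sampling alternative, using torch.utils.data.WeightedRandomSampler on a stand-in TensorDataset. The random tensors and label counts are hypothetical placeholders for a real image dataset. Each sample is given the reciprocal frequency of its class, so batches come out roughly class-balanced:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced dataset: small random "images" and their labels.
labels = torch.tensor([0] * 200 + [1] * 200 + [2] * 200
                      + [3] * 20 + [4] * 20 + [5] * 10)
images = torch.randn(len(labels), 3, 8, 8)
dataset = TensorDataset(images, labels)

# Per-class weight = 1 / class count; each sample gets its class's weight.
class_counts = torch.bincount(labels, minlength=6).float()
sample_weights = (1.0 / class_counts)[labels]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # one epoch's worth of draws
    replacement=True,                 # rare samples can be drawn repeatedly
)
# Do not also pass shuffle=True: DataLoader rejects shuffle with a sampler.
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

# The sampler already rebalances classes, so the loss is left unweighted.
loss_fn = torch.nn.CrossEntropyLoss()
```

Because the sampler already rebalances the classes, the sketch uses an unweighted loss; weighting the loss as well would likely over-correct toward the rare classes.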

Best.

K. Frank
