Computing sample weights for multi-class image segmentation

To use a weighted random sampler in PyTorch, we need to compute sample weights. In image classification, where each image has a single label, we can compute sample weights using the sklearn library. For segmentation, however, each image contains multiple classes. I was wondering how we can compute the sample weights in this case.

Hi Carol!

I will answer in the context of semantic segmentation where you classify
each pixel in an image (rather than assigning the entire image to a single
class).

To first approximation, you won’t want (or be able) to use sample weights
for semantic segmentation. Instead you will want to use class weights
(e.g., as supported by the CrossEntropyLoss’s weight constructor
argument).
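A minimal sketch of the class-weight approach, assuming a four-class model that outputs logits of shape (N, C, H, W) and integer label masks of shape (N, H, W). The weight values are illustrative placeholders, not a recommendation; the point is that `weight=` rescales the loss of every pixel according to its target class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 4  # class-0 background, class-1 red, class-2 green, class-3 blue

# Illustrative per-class weights: background down-weighted relative to foreground.
class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0])
loss_fn = nn.CrossEntropyLoss(weight=class_weights)

# Segmentation shapes: logits (N, C, H, W), integer targets (N, H, W).
logits = torch.randn(2, num_classes, 8, 8)
targets = torch.randint(0, num_classes, (2, 8, 8))

# With the default reduction="mean", this is the weighted mean of per-pixel losses.
loss = loss_fn(logits, targets)
print(loss.item())
```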

Consider an example such as this:

You have images consisting of mostly grey background pixels (class-0).
They contain a smattering of red (class-1), green (class-2), and blue
(class-3) foreground pixels and you want to perform four-class semantic
segmentation. Let’s say that the (smallish number of) foreground pixels
in some images are mostly red, in others mostly green or blue, and in still
others mixed.

Let’s assume that every typical kind of image is reasonably common in your
data set, so that in this sense your data set is balanced. But from a classification
perspective, your data set is highly unbalanced because the large majority
of your pixels are background pixels, so class-0 is highly over-represented.

This is a problem because your network could simply learn to always
predict class-0 for a pixel and it would almost always be right.
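To make the imbalance concrete, here is a small sketch that counts pixels per class over a set of label masks and scores a "model" that predicts class-0 everywhere. The masks are hypothetical toy data (integer tensors with values 0 to 3), not anything from the original question.

```python
import torch

def pixel_class_frequencies(masks, num_classes):
    """Fraction of all pixels, across all masks, that belong to each class."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for mask in masks:
        counts += torch.bincount(mask.flatten(), minlength=num_classes)
    return counts.float() / counts.sum()

# Hypothetical toy data: three 10x10 masks, each mostly background (class-0)
# with one 2x2 patch of red (1), green (2) or blue (3) foreground.
masks = []
for fg_class in (1, 2, 3):
    mask = torch.zeros(10, 10, dtype=torch.long)
    mask[2:4, 2:4] = fg_class
    masks.append(mask)

freqs = pixel_class_frequencies(masks, num_classes=4)

# Pixel accuracy of a trivial predictor that outputs class-0 for every pixel.
all_background_acc = torch.cat([(m == 0).flatten() for m in masks]).float().mean()
print(freqs, all_background_acc)
```

On this toy data background makes up 96% of all pixels, so the trivial all-background predictor already reaches 96% pixel accuracy while never finding a single foreground pixel.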

But you can’t fix this problem with sample weights (and a weighted
random sampler) because all of your samples (images) have class-0
pixels highly over-represented.
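A sketch of why resampling images does not help here, using the same kind of hypothetical toy masks: when every image is about 96% background, the expected background fraction of a sampled batch is a weighted average of per-image fractions, which stays at 96% no matter which sampling weights you pick. The weights below are arbitrary illustrative values.

```python
import torch

# Hypothetical toy masks: each 10x10 image is mostly background (class-0)
# with a 2x2 patch of red (1), green (2) or blue (3) foreground.
masks = []
for fg_class in (1, 2, 3):
    mask = torch.zeros(10, 10, dtype=torch.long)
    mask[2:4, 2:4] = fg_class
    masks.append(mask)

# Background fraction of each individual image (0.96 for every one).
bg_fraction = torch.stack([(m == 0).float().mean() for m in masks])

# Arbitrary per-image sampling weights, as a weighted sampler would use.
sample_weights = torch.tensor([0.7, 0.2, 0.1])

# Expected background fraction of a sampled batch: unchanged by the weights.
expected_bg = (sample_weights * bg_fraction).sum() / sample_weights.sum()
print(expected_bg)
```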

Instead, use class weights in your loss function to underweight class-0
pixels. Now if your network tries to predict class-0 for all pixels, the
large number of correct class-0 predictions will be underweighted, so
the network will have to learn to predict class-1, class-2, and class-3
pixels correctly to do well with the loss function.
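The answer does not prescribe a formula for the class weights. One common heuristic (an assumption here) is inverse pixel frequency in the "balanced" form `total_pixels / (num_classes * pixels_in_class)`, the same formula sklearn uses for `class_weight="balanced"`. A minimal sketch on hypothetical toy masks:

```python
import torch
import torch.nn as nn

def balanced_class_weights(masks, num_classes):
    """Inverse-frequency class weights: total / (num_classes * count_c)."""
    counts = torch.zeros(num_classes, dtype=torch.long)
    for mask in masks:
        counts += torch.bincount(mask.flatten(), minlength=num_classes)
    total = counts.sum().float()
    # clamp avoids division by zero if a class never appears in the masks
    return total / (num_classes * counts.clamp(min=1).float())

# Hypothetical toy masks: mostly background plus one 2x2 colored patch each.
masks = []
for fg_class in (1, 2, 3):
    mask = torch.zeros(10, 10, dtype=torch.long)
    mask[2:4, 2:4] = fg_class
    masks.append(mask)

weights = balanced_class_weights(masks, num_classes=4)
loss_fn = nn.CrossEntropyLoss(weight=weights)
print(weights)
```

On this toy data background gets a weight of about 0.26 and each foreground color 18.75, so every class contributes the same total weight (weight times pixel count). Other choices, such as 1/sqrt(frequency), are gentler and are also used in practice.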

(As a counter-example, let’s say most of your images are about half
background and half foreground. Let’s further say that most of your
images have red foreground pixels and only a few have green or blue
foreground pixels. Now it would make sense to use sample weights
where you increase the probability of sampling green and blue images
relative to red so that a batch, on average, consists of a third each of
red, green, and blue images.)
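For the counter-example, a sketch of per-image sample weights with PyTorch's `WeightedRandomSampler`. It assumes each image can be grouped by its dominant foreground class (`dominant_foreground_class` is a hypothetical helper, not a library function) and gives each image a weight of 1 / (number of images in its group), so each color group is drawn with roughly equal probability.

```python
import torch
from torch.utils.data import WeightedRandomSampler

def dominant_foreground_class(mask, num_classes):
    """Hypothetical helper: most frequent non-background class in a mask."""
    counts = torch.bincount(mask.flatten(), minlength=num_classes)
    return int(counts[1:].argmax()) + 1  # skip class-0 (background)

def image_sample_weights(masks, num_classes):
    groups = torch.tensor([dominant_foreground_class(m, num_classes) for m in masks])
    group_sizes = torch.bincount(groups, minlength=num_classes)
    # Each image gets 1 / (size of its group), so every group's weights sum to 1.
    return 1.0 / group_sizes[groups].float(), groups

# Hypothetical toy data: 8 red, 1 green and 1 blue image, each half
# background (top two rows) and half foreground (bottom two rows).
masks = []
for fg_class in [1] * 8 + [2, 3]:
    mask = torch.zeros(4, 4, dtype=torch.long)
    mask[2:, :] = fg_class
    masks.append(mask)

weights, groups = image_sample_weights(masks, num_classes=4)

# Draws image indices with replacement, proportionally to their weights.
sampler = WeightedRandomSampler(weights, num_samples=len(masks), replacement=True)
# Typically passed as DataLoader(dataset, batch_size=..., sampler=sampler).
```

Sampled this way, a batch on average contains about a third each of red, green, and blue images, even though red images make up 80% of this toy data set.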

Best.

K. Frank

Thank you very much for the explanation.