Class imbalance with image segmentation

Hi, I am trying to deal with the class imbalance problem in image segmentation.
My target masks are 128x128 and I have 3300 training images, so if I concatenate all of the masks, the total target tensor is 422400 x 128 (3300 * 128 = 422400 rows).
I built this mask_total tensor, which has the size torch.Size([422400, 128]), and I tried to build the sampler with the code below.

unique_color, count = np.unique(mask_total.cpu(), return_counts = True)
weight = 1. / count
samples_weight = weight[mask_total]

It gives me this error:

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

since my mask_total is 422400 x 128.
So with 128x128 masks, what should I implement to get the samples_weight?
Thank you a lot.

Could you explain your use case a bit more and how you would like to sample the masks?
Usually you would create sample weights, which assign a certain weight to each sample in the Dataset (each image and mask pair). From your description it looks like you would like to sample pixels from your mask. Is that correct, or do I misunderstand your use case?
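In case sampling individual pixels is really what you want, here is a minimal sketch of how per-pixel weights could be computed (untested, and assuming mask_total stores contiguous integer class indices starting at 0). The indexing error usually means the index is not a plain integer NumPy array, so the mask is converted to a flat int64 array before indexing:

import numpy as np
import torch

# assumption: mask_total is a [422400, 128] tensor of integer class labels
mask_np = mask_total.cpu().numpy().astype(np.int64).ravel()  # one entry per pixel

unique_color, count = np.unique(mask_np, return_counts=True)
weight = 1. / count                      # one weight per class
samples_weight = weight[mask_np]         # one weight per pixel
samples_weight = torch.from_numpy(samples_weight)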

Also, for reference: this answer might belong in this thread.

I’m writing a multilabel classifier where the targets are multi-hot encoded vectors for 13 classes, like below

0              "I'm a category 7 sample"  ...  [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
1              "I'm a category 8 sample"  ...  [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]
2        "I'm a category 1 AND 6 sample"  ...  [1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]

The Dataset is heavily imbalanced, with the following distribution for 20,000 samples:

{'Cat 1': 450,
 'Cat 2': 364,
 'Cat 3': 37,
 'Cat 4': 334,
 'Cat 5': 630,
 'Cat 6': 1096,
 'Cat 7': 918,
 'Cat 8': 3324,
 'Cat 9': 2053,
 'Cat 10': 532,
 'Cat 11': 1110,
 'Cat 12': 101,
 'Cat 13': 776}

There are roughly 10,000 unlabelled samples in this dataset.

To address the imbalance in training, I am trying to use WeightedRandomSampler, based on this answer (Some problems with WeightedRandomSampler) by “ptrblck”.

Here’s the relevant code:

# dist.values = [450, 364, 37, 334, 630, 1096, 918, 3324, 2053, 532, 1110, 101, 776, 9967]

weights = [1./v for v in dist.values()]
weights = torch.tensor(weights, dtype=torch.float)

# weights: tensor([0.0022, 0.0027, 0.0270, 0.0030, 0.0016, 0.0009, 0.0011, 0.0003, 0.0005,
#        0.0019, 0.0009, 0.0099, 0.0013, 0.0001])

y = torch.tensor(train_dataset['list'], dtype=torch.long)
# y:
# [   [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
#     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#     ...
#     [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
#     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#     [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],]

samples_weights = weights[y]

# for 17,000 training samples
# samples_weights torch.Size([17000, 13]):
[
[0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0027, 0.0022, 0.0022, 0.0022,
        0.0022, 0.0022, 0.0022, 0.0022],
[0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022,
        0.0022, 0.0022, 0.0022, 0.0022],
...
[0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022,
        0.0022, 0.0027, 0.0022, 0.0022],
[0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0027, 0.0022, 0.0022,
        0.0022, 0.0022, 0.0022, 0.0022],
[0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022,
        0.0022, 0.0022, 0.0022, 0.0022],]

Since my targets are multi-hot encoded, every entry of samples_weights ends up being either weights[0] or weights[1], which is obviously wrong.

Is there any way I can make this work with the multi-hot targets?

Shouldn’t samples_weight be the element-wise multiplication of the weights vector and the targets for that sample?
For example, the samples_weight for sample [0] above would be:

[0, 0, 0, 0, 0, 0.0009, 0, 0, 0, 0, 0, 0, 0]

This is all new to me so I appreciate any insight.


If I’m understanding what you’re after, you really need to do

samples_weights = y*weights.T

but to do that, you’re going to have to either add a 14th element representing “unlabeled” for every sample in y, OR cut out the 14th element of weights that represents the unlabeled category.
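As a rough sketch of that idea (untested, and assuming dist still contains the extra unlabelled count as the 14th entry, which gets dropped here), the per-sample weights could then be reduced to one scalar per sample, since WeightedRandomSampler expects a single weight for each sample:

import torch
from torch.utils.data import WeightedRandomSampler

# drop the 14th ("unlabelled") entry so weights lines up with the 13 target columns
weights = torch.tensor([1. / v for v in list(dist.values())[:13]], dtype=torch.float)

per_class_weights = y.float() * weights          # [17000, 13], zero where a class is absent
samples_weights = per_class_weights.sum(dim=1)   # one scalar weight per sample

sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=len(samples_weights),
    replacement=True
)

Note that fully unlabelled samples (all-zero rows in y) end up with a weight of 0 and would never be drawn, so you may want to give them a small nonzero weight if they should stay in training.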

Also, welcome to the forum. Hope this helps and good luck with your project! :wink:


Thank you!

Good eye on the 14th “class”! (I left it in by mistake because I wanted to count the unlabelled samples for this post.)

I’ll re-run the training and see if this works.
