How to make use of class weights + data augmentation together?

Does torchvision.transforms or albumentations provide some means to make use of the class weights calculated for the training data?

There are several arguments against using class weights in the loss function and performing data augmentation at the same time. Most of them cite the same reason:

the augmentation process changes the original class distribution

This made me look for ways to force the augmentation process to consider minority class samples more often than majority class samples.
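For concreteness, one way this could be done in PyTorch is with torch.utils.data.WeightedRandomSampler, which draws minority-class samples more often so they pass through the augmentation pipeline more frequently. A rough sketch (targets and train_dataset are placeholder names, not from any library):

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# targets: 1-D tensor of integer class labels for the training set
class_counts = torch.bincount(targets).float()
sample_weights = 1.0 / class_counts[targets]  # rarer class -> drawn more often
sampler = WeightedRandomSampler(sample_weights,
                                num_samples=len(targets),
                                replacement=True)
# Every drawn sample still goes through the dataset's transforms,
# so minority-class images end up being augmented more often.
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)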

What if class weights are computed on the fly for each augmented batch of training examples (as opposed to the traditional way of calculating them once for the whole training data)? This way both techniques might work together.

Your experiences with using data augmentation to counter the class imbalance problem would be highly appreciated :slightly_smiling_face:

If you are using CrossEntropyLoss, my idea is that you could change the weight parameter on the fly for your batch. Something along the lines of:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()

...
for input, target in loader:  # a DataLoader yielding batches
    cel_weights = calculate_weights(target)  # per-batch class weights
    criterion.weight = cel_weights  # the attribute is `weight`, not `weights`
    loss = criterion(model(input), target)
    ...
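calculate_weights is not defined in the snippet above; a minimal sketch of what it could do, assuming integer labels and a known number of classes (the function body and the num_classes parameter are my assumptions, not from the thread):

import torch

def calculate_weights(target, num_classes):
    # Inverse-frequency weights computed from this batch's labels only.
    counts = torch.bincount(target, minlength=num_classes).float()
    present = counts > 0
    weights = torch.zeros(num_classes)
    # Classes absent from the batch produce no loss terms,
    # so their weight can safely stay at zero.
    weights[present] = counts[present].sum() / (present.sum() * counts[present])
    return weights

With a scheme like this, each batch is weighted as if its present classes were balanced.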

If you are not using CrossEntropyLoss, check whether your criterion has a weight parameter that you can change in the same way.
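For example, nn.NLLLoss accepts the same kind of per-class weight tensor, while nn.BCEWithLogitsLoss handles binary imbalance through pos_weight instead (reusing cel_weights from the snippet above; the 3.0 is an arbitrary example value):

import torch
import torch.nn as nn

# NLLLoss takes a per-class weight tensor, just like CrossEntropyLoss:
criterion = nn.NLLLoss(weight=cel_weights)
# For binary targets, BCEWithLogitsLoss weights the positive class instead:
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))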

Thanks, @Manuel_Alejandro_Dia, for your reply!

Let me rephrase the questions:

  1. Is it a general rule that class weights should be pre-calculated for the entire training data rather than calculated for each batch on the fly?
  2. Is the use of “class weights” in loss calculation and “data augmentation” mutually exclusive?
  3. Should one use weights calculated from each input batch of A) augmented images or B) images before augmenting them?

From my experience, yes. Pre-calculating over the whole dataset means you process the weights once, instead of redoing the work for every batch.
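For reference, the usual pre-calculation looks something like this (train_targets is an assumed 1-D tensor of all training labels):

import torch
import torch.nn as nn

# Computed once from the full training set, before any augmentation:
class_counts = torch.bincount(train_targets).float()
class_weights = class_counts.sum() / (len(class_counts) * class_counts)
criterion = nn.CrossEntropyLoss(weight=class_weights)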

That is a good question. I would say it depends on your data augmentation whether you need to recalculate the weights. For example, the usual data augmentation I do on images is a horizontal flip and re-scaling, and I would argue that these do not change the class ratios enough to require recalculating the class weights.
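For instance, a pipeline like the following only applies label-preserving transforms, so each class still contributes the same number of samples:

import torchvision.transforms as T

train_transform = T.Compose([
    T.RandomHorizontalFlip(),  # geometry changes, the label does not
    T.Resize((224, 224)),      # re-scaling does not change the label either
    T.ToTensor(),
])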

If you think that your data augmentation changes the ratios between classes by a considerable amount, then I would suggest one of the following:

  • Update the weights on the fly, but be aware that some classes might not be present in a given batch when you recalculate.
  • Apply the augmentation to your entire dataset and compute the weights from there. Since some augmentations are based on random values, I would suggest running this a few times and averaging the class weights after augmentation (see the sketch below).
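A sketch of that second option, assuming the post-augmentation labels can be read back from a DataLoader (loader, num_classes, and passes are all placeholder names):

import torch

def averaged_class_weights(loader, num_classes, passes=5):
    # Random augmentations shift the observed label counts from run to
    # run, so accumulate counts over several full passes and average.
    totals = torch.zeros(num_classes)
    for _ in range(passes):
        for _, target in loader:
            totals += torch.bincount(target, minlength=num_classes).float()
    totals = totals.clamp(min=1.0)  # guard against classes never observed
    return totals.sum() / (num_classes * totals)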

Thank you for the questions!
I hope this helps!