Normalization of losses with different scales

I am currently trying to train a model to do localization (classification and bounding box prediction for images), but I have run into trouble with the difference in magnitude between the values of my different loss functions.

My model consists of a ResNet50 backbone with two heads: one for classification, and one for predicting the bounding box. My loss function is the sum of two losses: cross entropy loss on the output of the classification head, and mean squared error loss on the output of the bounding box prediction head.

As I have gathered, it is very common to simply sum the different loss functions into a total loss that the model minimizes. However, the difference in scale between the cross entropy output and the mean squared error output is very large: at the start of training the cross entropy loss is typically on the order of 10^2, while the mean squared error loss is on the order of 10^5. This means that the model focuses almost exclusively on reducing the mean squared error loss, i.e. it only tries to predict better bounding boxes, since that is by far the most effective way to reduce the total loss.

To solve this, I want to normalize the loss functions to a similar scale before summing them into the total loss (this is described as a common approach in How to normalize losses of different scale). However, I am not entirely sure how to do this in PyTorch without interfering with training when I call .backward() on the total loss.

Without any rescaling, my loss calculation looked like this (prediction_scores are the predicted class scores for the batch, class_batch are the ground truth labels, prediction_boxes are the predicted bounding box coordinates, and box_batch are the ground truth bounding box coordinates):

loss_class = F.cross_entropy(prediction_scores, class_batch)
loss_box = F.mse_loss(prediction_boxes, box_batch)
total_loss = loss_class + loss_box

total_loss.backward()

I have tried rescaling the mse_loss in the following way:

loss_class = F.cross_entropy(prediction_scores, class_batch)
loss_box = F.mse_loss(prediction_boxes, box_batch)

with torch.no_grad():
    loss_box = torch.div(loss_box, 10000)

total_loss = loss_class + loss_box

total_loss.backward()

I included the torch.no_grad() in the hope that when I call .backward() on total_loss, the division by 10000 is essentially ignored during training. Is this the right approach for what I am trying to achieve? Is there something else I should do to rescale my loss functions?
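For comparison, simply dividing by a constant without the no_grad block would look like this (the 10000 is just the rough scale factor I observed between my two losses, not anything principled):

loss_class = F.cross_entropy(prediction_scores, class_batch)
loss_box = F.mse_loss(prediction_boxes, box_batch)

# the constant scales both the loss value and its gradient
total_loss = loss_class + loss_box / 10000

total_loss.backward()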

Hopefully my question makes sense. Thanks in advance!

This is a problem that comes up in reinforcement learning. How do you balance multiple objective functions?

If you can identify a maximum loss value for each type of loss, you can use something like this:

$$\text{total\_loss} = \sum_{n} \gamma_n \, \frac{a_n}{\max(A_n)}$$

where A_n represents the set of losses observed for a given objective n during training (so max(A_n) is the largest loss seen so far), a_n represents the current loss, and gamma_n is just an additional weight, in case you want one objective to be prioritized over another.

In PyTorch code, that would be:

# suppose we have loss_class and loss_box calculated, and have determined an approximate max value for each separately

total_loss = loss_class/loss_class_max + loss_box/loss_box_max

You could also take the mean of the above, or reduce the learning rate, though with only 2 loss types it probably won't make much difference.
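If you also want the gamma weights from the formula above, a minimal sketch (the 0.5 is just an arbitrary example value) would be:

# example weights; here classification is prioritized over box regression
gamma_class, gamma_box = 1.0, 0.5
total_loss = gamma_class * loss_class / loss_class_max + gamma_box * loss_box / loss_box_max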

Getting the max loss can be handled with a simple evaluation function between the calculation of the individual losses and the calculation of the total loss:

def max_loss_eval(current_loss, max_loss):
    # keep a running max; detach so the stored max does not hold on to the graph
    if current_loss > max_loss:
        return current_loss.detach()
    else:
        return max_loss
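As a rough sketch of how this could fit into a training loop (model, optimizer and dataloader are placeholders, and the 1e-8 starting values are just arbitrary small initializers so the first batch sets the real maxima):

import torch.nn.functional as F

# running maxima, start tiny so the first observed losses replace them
loss_class_max = 1e-8
loss_box_max = 1e-8

for images, class_batch, box_batch in dataloader:
    prediction_scores, prediction_boxes = model(images)

    loss_class = F.cross_entropy(prediction_scores, class_batch)
    loss_box = F.mse_loss(prediction_boxes, box_batch)

    # update the running max for each loss (detached inside max_loss_eval)
    loss_class_max = max_loss_eval(loss_class, loss_class_max)
    loss_box_max = max_loss_eval(loss_box, loss_box_max)

    # each term is now roughly in the 0-1 range
    total_loss = loss_class / loss_class_max + loss_box / loss_box_max

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()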

Thank you for sharing this approach. Especially for PPO, where there is a policy gradient loss, a value loss and an entropy loss, this seems useful. But so far I have never seen this mentioned elsewhere. Do you know of any RL papers that use this type of max loss scaling?

No. I haven’t written any papers. Was just thinking about the problem one day and this was what I arrived at. Normalization is a pretty straightforward problem, though.

Assuming the model converges to the data, the max_loss will likely be set in one of the first few passes, if not the very first, and then remain static for the rest of training. This guarantees each loss is normalized between 0 and 1.