I am currently trying to train a model to do localization (classification and bounding box prediction of images), but I have run into trouble with the difference in magnitude between the values of the two loss functions.
My model consists of a ResNet50 backbone with two heads: one for classification and one for predicting the bounding box. My total loss is the sum of two loss functions: cross entropy loss on the output of the classification head, and mean squared error loss on the output of the bounding box prediction head.
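For reference, a minimal sketch of this kind of two-head setup might look like the following. It is an illustration, not my exact code: the class name TwoHeadLocalizer, the single linear layer per head, and the four-value box output are assumptions, and a tiny stand-in backbone is used for the shape check instead of ResNet50.

```python
import torch
import torch.nn as nn


class TwoHeadLocalizer(nn.Module):
    """Hypothetical two-head model: a shared backbone feeds a
    classification head and a bounding box regression head."""

    def __init__(self, backbone: nn.Module, feat_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone                          # shared feature extractor
        self.cls_head = nn.Linear(feat_dim, num_classes)  # class logits for cross entropy
        self.box_head = nn.Linear(feat_dim, 4)            # 4 box coordinates (assumed format)

    def forward(self, x: torch.Tensor):
        feats = self.backbone(x)                          # expected shape: (N, feat_dim)
        return self.cls_head(feats), self.box_head(feats)


# Tiny stand-in backbone for a quick shape check (NOT ResNet50).
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU())
model = TwoHeadLocalizer(backbone, feat_dim=64, num_classes=10)
prediction_scores, prediction_boxes = model(torch.randn(2, 3, 32, 32))
print(prediction_scores.shape, prediction_boxes.shape)  # torch.Size([2, 10]) torch.Size([2, 4])
```

With torchvision, a ResNet50 backbone could be plugged in by replacing the final fc layer of torchvision.models.resnet50() with nn.Identity() and passing feat_dim=2048, the size of ResNet50's pooled features.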
As far as I can tell, it is very common to simply sum the individual loss functions to get the total loss that the model is trained on. However, the scales of my cross entropy loss and mean squared error loss differ a lot: the cross entropy loss typically has values on the order of 10^2 at the start of training, while the mean squared error loss has values on the order of 10^5. As a result, training focuses almost entirely on reducing the mean squared error loss, i.e., the model only tries to predict better bounding boxes, because that reduces the total loss far more than improving the classification does.
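One way to see why the larger loss dominates: the gradient of a sum is the sum of the gradients, so whichever term produces larger gradients on the shared backbone weights drives the updates. The sketch below measures each loss's gradient norm on a shared layer separately. It is a toy setup, not my model: the layer sizes are arbitrary, and the box targets are assumed to be pixel coordinates in [0, 500), which is what makes the MSE large.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy model: one shared layer (standing in for the backbone) and two heads.
shared = nn.Linear(8, 16)
cls_head = nn.Linear(16, 5)
box_head = nn.Linear(16, 4)

x = torch.randn(32, 8)
class_batch = torch.randint(0, 5, (32,))
box_batch = torch.rand(32, 4) * 500  # assumed pixel-scale box targets -> large MSE

feats = shared(x)
loss_class = F.cross_entropy(cls_head(feats), class_batch)
loss_box = F.mse_loss(box_head(feats), box_batch)


def shared_grad_norm(loss: torch.Tensor) -> float:
    """L2 norm of the gradient of `loss` w.r.t. the shared layer's parameters."""
    # retain_graph=True so the same graph can be differentiated again for the other loss.
    grads = torch.autograd.grad(loss, list(shared.parameters()), retain_graph=True)
    return torch.sqrt(sum((g ** 2).sum() for g in grads)).item()


print(f"class loss {loss_class.item():.3g}, shared grad norm {shared_grad_norm(loss_class):.3g}")
print(f"box loss   {loss_box.item():.3g}, shared grad norm {shared_grad_norm(loss_box):.3g}")
```

Since total_loss.backward() simply adds these two gradients together, the term with the much larger norm largely determines how the shared weights move.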
To solve this, I want to normalize the loss functions so that they have a similar scale before I sum them into the total loss (this is described as a common approach in How to normalize losses of different scale). However, I am not entirely sure how to do this in PyTorch without interfering with training when calling .backward() on the total loss.
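One simple form of such rescaling is a weighted sum with constant Python floats. Autograd treats a constant factor like any other operation, so multiplying a loss by w multiplies its gradient by exactly w. The sketch below checks that on toy tensors; the weights (w_class = 1.0, w_box = 1e-3) are hypothetical placeholders that only bring a 10^5-scale loss down to roughly 10^2, not tuned values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical weights: plain constants, not tensors that require grad.
w_class, w_box = 1.0, 1e-3

# Toy predictions that require grad, standing in for the two head outputs.
prediction_scores = torch.randn(8, 5, requires_grad=True)
prediction_boxes = torch.randn(8, 4, requires_grad=True)
class_batch = torch.randint(0, 5, (8,))
box_batch = torch.rand(8, 4) * 500

loss_class = F.cross_entropy(prediction_scores, class_batch)
loss_box = F.mse_loss(prediction_boxes, box_batch)

# Gradient of the unweighted box loss, for comparison (does not touch .grad).
(grad_unweighted,) = torch.autograd.grad(loss_box, prediction_boxes, retain_graph=True)

# Weighted total loss; backward() flows through both terms.
total_loss = w_class * loss_class + w_box * loss_box
total_loss.backward()

# The box gradient is the unweighted gradient scaled by w_box.
print(torch.allclose(prediction_boxes.grad, w_box * grad_unweighted))  # True
```

The weights are hyperparameters, so suitable values depend on the actual loss magnitudes; the sketch only shows that the weighting works with a plain .backward() call.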
Without any rescaling, my loss calculation looked like this, where prediction_scores are the predicted class scores for the images in the batch, class_batch are the ground truth labels, prediction_boxes are the predicted bounding box coordinates, and box_batch are the ground truth bounding box coordinates:
```python
import torch.nn.functional as F

loss_class = F.cross_entropy(prediction_scores, class_batch)
loss_box = F.mse_loss(prediction_boxes, box_batch)
total_loss = loss_class + loss_box
total_loss.backward()
```
I have tried rescaling the MSE loss in the following way:
```python
import torch
import torch.nn.functional as F

loss_class = F.cross_entropy(prediction_scores, class_batch)
loss_box = F.mse_loss(prediction_boxes, box_batch)
with torch.no_grad():
    loss_box = torch.div(loss_box, 10000)
total_loss = loss_class + loss_box
total_loss.backward()
```
I included torch.no_grad() in the hope that when I call .backward() on total_loss, the division by 10000 is essentially ignored during training. Is this the right approach for what I am trying to achieve? Is there something else I should do to rescale my loss functions?
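A quick way to check what the torch.no_grad() block does is to inspect requires_grad on the result and see which parameters receive gradients. In this toy sketch (two separate linear layers stand in for the classification and box heads; all sizes are arbitrary assumptions), the tensor produced inside no_grad() is not part of the autograd graph, so after backward() only the classification head has a gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy stand-ins for the two heads (arbitrary sizes).
cls_head = nn.Linear(16, 5)
box_head = nn.Linear(16, 4)
feats = torch.randn(8, 16)
class_batch = torch.randint(0, 5, (8,))
box_batch = torch.rand(8, 4) * 500

loss_class = F.cross_entropy(cls_head(feats), class_batch)
loss_box = F.mse_loss(box_head(feats), box_batch)
print(loss_box.requires_grad)  # True: still connected to box_head

with torch.no_grad():
    loss_box = torch.div(loss_box, 10000)
print(loss_box.requires_grad)  # False: the new tensor is outside the graph

total_loss = loss_class + loss_box
total_loss.backward()
print(cls_head.weight.grad is not None)  # True
print(box_head.weight.grad is None)      # True: box head received no gradient
```

By contrast, dividing outside no_grad() (for example loss_box / 10000) keeps the division in the graph and scales the box-head gradients by the same factor.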
Hopefully my question makes sense. Thanks in advance!