How to handle Multiple Losses

I have the model depicted in the figure. Model 1 and model 2 used to be two disjoint models that worked in a pipeline: we first trained model 1 to convergence and then fed its preprocessed outputs to model 2 as inputs. I am now training them end to end, and I am struggling with how to integrate the two losses instead of just using loss2. Another question is: can I use two different optimizers, one for each model's parameters?

Should I call loss1.backward() first to compute gradients for model 1, and then loss2.backward(), which will compute gradients for both model 1's and model 2's parameters? Do you think this is a good idea, where gradients accumulate from both losses and each model has its own learning rate, so that I can force model 1 to learn more from loss1 than from loss2?
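Concretely, I mean something like the following sketch (model1, model2, the MSE losses, and the learning rates here are just placeholders for my actual modules and losses):

```python
import torch

# stand-ins for the real modules: model1 ~ A, B, C and model2 ~ D, E
model1 = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
model2 = torch.nn.Linear(10, 1)

# two optimizers so each model can have its own learning rate
opt1 = torch.optim.SGD(model1.parameters(), lr=1e-4)
opt2 = torch.optim.SGD(model2.parameters(), lr=1e-3)

x = torch.randn(8, 10)
target1 = torch.randn(8, 10)   # target for model 1's (intermediate) output
target2 = torch.randn(8, 1)    # target for model 2's (final) output

intermediate = model1(x)
output = model2(intermediate)

loss1 = torch.nn.functional.mse_loss(intermediate, target1)
loss2 = torch.nn.functional.mse_loss(output, target2)

opt1.zero_grad()
opt2.zero_grad()
loss1.backward(retain_graph=True)   # gradients for model 1 only
loss2.backward()                    # gradients for model 1 (accumulated) and model 2
opt1.step()
opt2.step()
```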

Another idea that came to my mind is to sum loss1 and loss2 (let's call the sum loss3) and backpropagate that. My initial intuition is that loss3 will only backpropagate loss2 until it reaches C, and from there on backpropagate the weighted sum. Is that right?

Any ideas or references to the literature will be appreciated.


Hello Ahmed!

Yes, this is, in effect, what will happen.

Summing the losses is the approach I would take. I would probably
include relative weights: loss = w1 * loss1 + w2 * loss2.
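Something along these lines (a minimal, self-contained sketch;
model1, model2, the MSE losses, and the single optimizer are just
placeholders for your actual setup):

```python
import torch

# placeholder stand-ins: model1 ~ modules A, B, C; model2 ~ modules D, E
model1 = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
model2 = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(
    list(model1.parameters()) + list(model2.parameters()), lr=1e-3)

x = torch.randn(8, 10)
target1 = torch.randn(8, 10)   # target for the intermediate output
target2 = torch.randn(8, 1)    # target for the final output

w1, w2 = 0.1, 1.0   # relative weights, to be tuned as hyperparameters

intermediate = model1(x)
output = model2(intermediate)
loss1 = torch.nn.functional.mse_loss(intermediate, target1)
loss2 = torch.nn.functional.mse_loss(output, target2)

loss = w1 * loss1 + w2 * loss2   # single combined loss
optimizer.zero_grad()
loss.backward()   # one backward pass; both terms reach A, B, C; only loss2 reaches D, E
optimizer.step()
```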

Note that loss1 does not depend on the weights in modules D and
E, so backpropagating loss1 + loss2 will have the same effect on
modules D and E as backpropagating loss2 alone.

But, as you correctly reason, both loss1 and loss2 will affect
modules A, B, and C when backpropagating loss1 + loss2.
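You can verify this with a toy check (again with placeholder
modules and losses standing in for your real ones): backpropagating
loss1 alone never produces gradients for model 2's parameters.

```python
import torch

model1 = torch.nn.Linear(10, 10)   # stands in for A, B, C
model2 = torch.nn.Linear(10, 1)    # stands in for D, E

intermediate = model1(torch.randn(8, 10))
output = model2(intermediate)

loss1 = intermediate.pow(2).mean()   # toy loss on the intermediate output
loss2 = output.pow(2).mean()         # toy loss on the final output

loss1.backward(retain_graph=True)
print([p.grad for p in model2.parameters()])              # [None, None]
loss2.backward()
print([p.grad is not None for p in model2.parameters()])  # [True, True]
```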

Be aware, though, that including loss1 could push A, B, and
C away from where they would provide the best performance for
the output of your overall model (and lead to higher values of
loss2).

On the other hand, including loss1 could help prevent A, B, and
C from overfitting your training set, potentially providing better
performance on your test set.

Good luck.

K. Frank


Hi Frank,

Thanks for your detailed answer. Your explanation makes sense. I am actually running an experiment with that setup and the loss seems to be decreasing. Thank you!