Multiple loss functions


I am working on a problem where I use two loss functions together, i.e.
Total_loss = cross_entropy_loss + custom_loss
and then call Total_loss.backward().
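Here is a minimal sketch of that setup in PyTorch. It assumes a toy `nn.Linear` classifier and a hypothetical `custom_loss_fn`, since the actual custom loss is not shown in the post. Calling `backward()` on the sum accumulates the gradients of both terms into each parameter's `.grad`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy classifier; the real model is not shown in the post.
model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def custom_loss_fn(logits):
    # Hypothetical stand-in for the custom loss: unbounded below,
    # so, like the one described, it can go negative.
    return -logits.std(dim=1).mean()


x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

logits = model(x)
cross_entropy_loss = F.cross_entropy(logits, y)
custom_loss = custom_loss_fn(logits)

total_loss = cross_entropy_loss + custom_loss
optimizer.zero_grad()
total_loss.backward()  # gradients of both terms are summed into each .grad
optimizer.step()
```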

Over a training run of, say, 20 epochs, the cross-entropy loss drops from 2.4 to about 0.7. It decreases quickly until roughly the 8th epoch and only very slowly after that.

The custom loss, by contrast, keeps decreasing roughly linearly (or even exponentially) over the entire run, from 3.2 down to approximately -10.

That suggests that after a certain epoch the model focuses mainly on the custom loss and almost ignores the cross-entropy loss!
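One thing worth checking: because the gradient of a sum is the sum of the gradients, which term dominates each update is decided by the size of each term's gradient with respect to the shared parameters, not by the loss values themselves. The sketch below measures this with `torch.autograd.grad`. It uses a toy model, and the custom loss is a hypothetical stand-in:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def grad_norm(loss, params):
    # Gradient of one loss term w.r.t. the shared parameters.
    # retain_graph=True keeps the graph alive so the other term can reuse it.
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.sqrt(sum((g ** 2).sum() for g in grads if g is not None))


torch.manual_seed(0)
model = nn.Linear(8, 3)
params = list(model.parameters())
x = torch.randn(16, 8)
y = torch.randint(0, 3, (16,))

logits = model(x)
ce = F.cross_entropy(logits, y)
custom = -logits.std(dim=1).mean()  # hypothetical custom loss

ce_norm = grad_norm(ce, params)
custom_norm = grad_norm(custom, params)
print(f"CE grad norm: {ce_norm.item():.4f}, custom grad norm: {custom_norm.item():.4f}")
```

If the custom term's gradient norm is consistently much larger, it dominates the update direction. Tracking these two numbers during training is a more direct check than comparing the loss curves.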

So what can I do to make the model treat both losses equally and give them equal importance?

I am currently multiplying the custom loss by a small constant, but it doesn't seem to work.

Just an idea, but maybe you could clip the custom loss so that it eventually stops going down. Once the model hits that point, it might start focusing on the cross-entropy loss.
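Here is a minimal sketch of the clipping idea using `torch.clamp`. The `floor` value is a hypothetical parameter you would have to choose. Once the custom loss falls below the floor, `clamp` passes zero gradient, so the custom term stops pushing the parameters and only the cross-entropy loss drives the updates:

```python
import torch


def clipped(custom_loss: torch.Tensor, floor: float) -> torch.Tensor:
    # torch.clamp passes the gradient through while custom_loss >= floor
    # and gives zero gradient once custom_loss falls below floor.
    return torch.clamp(custom_loss, min=floor)


results = []
for w_value in (1.0, 2.0):
    w = torch.tensor(w_value, requires_grad=True)
    custom_loss = -3.0 * w  # hypothetical loss that keeps going down as w grows
    loss = clipped(custom_loss, floor=-5.0)
    loss.backward()
    results.append((loss.item(), w.grad.item()))
    print(w_value, loss.item(), w.grad.item())
```

In practice you would add `clipped(custom_loss, floor)` to the cross-entropy loss in place of the raw custom loss.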

Hi @Hdk,

A lot could be at play here, depending on your custom_loss. I would typically try the following:

  1. If possible, make the custom loss bounded below by zero.
  2. If that is not possible, use Total_loss = (10/2.4)*cross_entropy_loss + custom_loss
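Here is a sketch of both suggestions. The custom loss itself isn't shown, so for point 1 `F.softplus` serves only as an assumed example of a monotone transform that keeps the term non-negative. It is not part of the original reply. Point 2 is the static weighting as written:

```python
import torch
import torch.nn.functional as F


def total_loss_option1(ce: torch.Tensor, custom: torch.Tensor) -> torch.Tensor:
    # Assumption: softplus is one possible monotone transform that keeps the
    # custom term >= 0; its gradient, sigmoid(custom), shrinks as custom decreases.
    return ce + F.softplus(custom)


def total_loss_option2(ce: torch.Tensor, custom: torch.Tensor,
                       ce_weight: float = 10 / 2.4) -> torch.Tensor:
    # Static re-weighting of the cross-entropy term, as in point 2.
    return ce_weight * ce + custom


ce = torch.tensor(0.7)
custom = torch.tensor(-10.0)
print(total_loss_option1(ce, custom).item(), total_loss_option2(ce, custom).item())
```

Note that softplus(x) approaches 0 with a vanishing gradient as x decreases. In practice it therefore behaves like a smooth version of the clipping idea mentioned earlier in the thread.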

@Dwight_Foster As I mentioned, I want the model to treat both losses equally, i.e. ideally both should go down. So I don't want to discard the custom loss at any point during training.

The first point is not possible: the custom loss can be negative. I already tried the second one, but then the custom loss stopped decreasing, i.e. it stayed an almost flat line over the entire training run while the cross-entropy loss did decrease. That made things even worse, so I guess the model was then focusing only on the cross-entropy loss.

I was thinking of implementing a dynamic multiplier for the custom loss instead of a static one, i.e. a multiplier that changes over the course of training.
Does that make sense? If so, how should I decide which equation (e.g. ax**2+b…) the multiplier should follow, and how can it be implemented in the training loop?
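One way to implement a dynamic multiplier is to compute it from the epoch number at the start of each epoch and use it when building the total loss. The sketch below assumes a hypothetical quadratic schedule `a*epoch**2 + b` with `a < 0`, so the custom-loss weight decays over training and is floored at `w_min`. The values of `a`, `b`, and `w_min`, the toy model, and the custom loss are all placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def custom_weight(epoch, a=-0.005, b=1.0, w_min=0.1):
    # Hypothetical quadratic schedule a*epoch**2 + b; with a < 0 the
    # custom-loss weight decays over training and never drops below w_min.
    return max(w_min, a * epoch ** 2 + b)


torch.manual_seed(0)
model = nn.Linear(8, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.randn(64, 8)
y = torch.randint(0, 3, (64,))

weights = []
for epoch in range(20):
    w = custom_weight(epoch)  # multiplier for this epoch
    weights.append(w)
    logits = model(x)
    ce = F.cross_entropy(logits, y)
    custom = -logits.std(dim=1).mean()  # hypothetical custom loss
    total = ce + w * custom
    optimizer.zero_grad()
    total.backward()
    optimizer.step()
    print(f"epoch {epoch:2d}  w={w:.3f}  ce={ce.item():.3f}  custom={custom.item():.3f}")
```

Which schedule works best is problem-specific. The decaying form here simply encodes the goal of letting the cross-entropy term carry more weight later in training.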

Thank you for the reply.