Balance between 2 MSE losses

I have an MLP and I'm trying to predict two values from it (so I'm doing multi-task learning). The two values have losses of very different scales (one is in the tens and the other is in the 100,000s), so I can't simply add the losses together before backpropagating. I'm currently trying to learn a weight that balances the losses, but I don't know how to go about that. Below is a brief summary of the code:



import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, 100)
        self.out1 = nn.Linear(100, output_dim - 1)  # head for the first target
        self.out2 = nn.Linear(100, 1)               # head for the second target
        self.relu = nn.ReLU()

        # create a learnable parameter that is used to weight the two losses
        # (nn.Parameter already sets requires_grad=True)
        self.weight = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # shared hidden layer, computed once and fed to both heads
        hidden = self.relu(self.fc1(x))
        out1 = self.out1(hidden)
        out2 = self.out2(hidden)
        return out1, out2, self.weight

# in the train method (assume we have the target values and data)
predicted_one, predicted_two, weight = model(data)

loss_one = nn.functional.mse_loss(predicted_one, target_one)
loss_two = nn.functional.mse_loss(predicted_two, target_two)

total_loss = loss_one + weight * loss_two


I don't think this is the right way to go about it, because the weight doesn't learn. Please kindly assist.

Hi Ays!

If it were me, I would just treat weight as a hyperparameter to be tuned
by hand.
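Treating the weight as a hand-tuned hyperparameter could look like the minimal sketch below. The helper names and the idea of starting from the ratio of the two losses at initialization are assumptions for illustration, not something from the thread; the point is that `weight` is a plain float the optimizer never sees.

```python
import torch
import torch.nn.functional as F

def initial_weight_guess(pred_one, target_one, pred_two, target_two):
    # Hypothetical heuristic: ratio of the two losses at initialization,
    # used only as a starting value for a hand-tuned hyperparameter
    with torch.no_grad():
        loss_one = F.mse_loss(pred_one, target_one)
        loss_two = F.mse_loss(pred_two, target_two)
    return (loss_one / loss_two).item()

def total_loss(pred_one, target_one, pred_two, target_two, weight):
    # weight is a plain Python float, so the optimizer never updates it
    loss_one = F.mse_loss(pred_one, target_one)
    loss_two = F.mse_loss(pred_two, target_two)
    return loss_one + weight * loss_two

# Toy tensors with roughly the scales described in the question
torch.manual_seed(0)
pred_one, target_one = torch.rand(8, 1), torch.rand(8, 1)
pred_two, target_two = torch.zeros(8, 1), torch.rand(8, 1) * 4_202_496

w = initial_weight_guess(pred_one, target_one, pred_two, target_two)
total = total_loss(pred_one, target_one, pred_two, target_two, w)
print(w, total.item())
```

From that starting value, the weight would then be adjusted by hand (for example, a few values on a log scale) while watching validation performance on both targets.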

There is, however, a proposed MultiTaskLoss (that I have never used) that
seems to be a theoretically-sound approach to automatically tuning the
relative weight between two losses.

Here is an implementation:

And here is the paper about the scheme that is referenced in the code:

Best.

K. Frank

Hi Frank. Thanks a lot for your response. I tried using the multi-task learning approach, but I'm getting a loss like the one in the graph below.

Silly question:

Did you try normalizing the target values?

(Say, to the range -1 to 1, perhaps with a tanh() final layer plus an MSE or L1 loss; during inference, rescale the values back accordingly.)
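A minimal sketch of that suggestion, assuming per-target min-max scaling to [-1, 1] computed from the training targets. `TargetScaler` and `TwoHeadMLP` are hypothetical names, and the random data only mimics the target ranges mentioned later in the thread:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TargetScaler:
    """Min-max scale targets to [-1, 1] using training-set statistics."""
    def __init__(self, targets: torch.Tensor):
        self.lo = targets.min(dim=0).values
        self.hi = targets.max(dim=0).values

    def transform(self, y):
        return 2 * (y - self.lo) / (self.hi - self.lo) - 1

    def inverse(self, y_scaled):
        return (y_scaled + 1) * (self.hi - self.lo) / 2 + self.lo

class TwoHeadMLP(nn.Module):
    def __init__(self, input_dim, hidden=100):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden)
        self.out1 = nn.Linear(hidden, 1)
        self.out2 = nn.Linear(hidden, 1)

    def forward(self, x):
        h = torch.relu(self.fc1(x))
        # tanh keeps both predictions in (-1, 1), matching the scaled targets
        return torch.tanh(self.out1(h)), torch.tanh(self.out2(h))

torch.manual_seed(0)
x = torch.randn(32, 5)
y_one = torch.rand(32, 1)               # targets in [0, 1]
y_two = torch.rand(32, 1) * 4_202_496   # targets in [0, 4202496]
s_one, s_two = TargetScaler(y_one), TargetScaler(y_two)

model = TwoHeadMLP(input_dim=5)
p_one, p_two = model(x)
# Both losses are now on the same scale, so a plain sum is reasonable
loss = F.mse_loss(p_one, s_one.transform(y_one)) + F.mse_loss(p_two, s_two.transform(y_two))
loss.backward()

# At inference, map predictions back to the original units
with torch.no_grad():
    pred_two_original_units = s_two.inverse(model(x)[1])
```

Because tanh never quite reaches ±1, the most extreme training targets can only be approached, not hit exactly; a plain linear output head on the scaled targets avoids that.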

Hi Arul,
Thanks for the suggestion. My first loss has target values ranging between 0 and 1, and my second loss has target values ranging between 0 and 4202496. So do you mean I should normalise the second target variable to values between 0 and 1? If yes, does that mean I wouldn't need a weight in the loss, because both would be on the same scale?

NB: sorry, I'm a bit new to all this, so I'm still trying to get the hang of it.

No worries. We are all learning :)

Yes, I meant keeping both target variables on the same scale (either [0, 1] or [-1, 1]).
The method proposed in Kendall et al.'s paper approaches this from a different perspective: it learns to weigh the losses based on their uncertainties (if the variance of a particular loss is very high, its contribution to the gradients will be low, and vice versa). I would still keep it (or you can experiment without it too).
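The uncertainty weighting can be sketched as below. This is an illustration of the form usually quoted for regression tasks in Kendall et al.'s paper, written with one learned log-variance s_i per task; the class name and constants are assumptions, and the implementation mentioned earlier in the thread may differ in detail (for example, some versions drop the 0.5 factors).

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Sketch of uncertainty-based loss weighting in the spirit of Kendall et al.

    Learns one log-variance s_i = log(sigma_i^2) per task and returns
        sum_i 0.5 * exp(-s_i) * L_i + 0.5 * s_i
    A large s_i (high uncertainty) shrinks that task's weight, while the
    0.5 * s_i term penalizes driving every weight to zero.
    """
    def __init__(self, num_tasks: int = 2):
        super().__init__()
        # s_i = 0 means sigma_i^2 = 1, so every task starts with weight 0.5
        self.log_vars = nn.Parameter(torch.zeros(num_tasks))

    def forward(self, losses):
        losses = torch.stack(losses)
        precision = torch.exp(-self.log_vars)
        return (0.5 * precision * losses + 0.5 * self.log_vars).sum()

criterion = UncertaintyWeightedLoss(num_tasks=2)
fixed_losses = [torch.tensor(10.0), torch.tensor(100_000.0)]
initial_value = criterion(fixed_losses).item()  # 0.5 * (10 + 100000) at s = 0

# At s_i = log(L_i) the gradient of each term vanishes, so for a fixed
# loss the learned weight 0.5 * exp(-s_i) settles at 0.5 / L_i
with torch.no_grad():
    criterion.log_vars.copy_(torch.log(torch.tensor([10.0, 100_000.0])))
criterion(fixed_losses).backward()
print(initial_value, criterion.log_vars.grad)
```

Setting the derivative of 0.5 * exp(-s_i) * L_i + 0.5 * s_i to zero gives s_i = log(L_i), so a loss in the 100,000s ends up with a far smaller weight than one in the tens. The criterion's `log_vars` must be handed to the optimizer together with the model's parameters, e.g. `torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()))`.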

Please let me know how it goes.

Here is a simple function you can apply to normalize 2 or more rewards in order to balance various objectives:

import torch

def norm_rewards(raw_rewards, alphas, freq_weights, max_clip, scale = 1.):
    raw_rewards = torch.clip(raw_rewards, max = max_clip)
    return torch.sum(alphas*freq_weights*raw_rewards/max_clip)*scale

raw_rewards = torch.tensor([0.5, 2_000_000.]) # these are the dynamic raw rewards from each objective
alphas = torch.tensor([2., 1.]) #this should be priority of each reward; if all equal, can use torch.ones_like(raw_rewards)
freq_weights = torch.tensor([4., 1.]) #this should be the inverse of the frequency the reward occurs, i.e. 25% of the time would be 4. and 100% of the time would be 1.
max_clip = torch.tensor([1., 4_202_496.]) #this should be the max expected value of each reward and can also be used to clamp reward values for stability
scale = 1. #optional scale value which scales the overall reward

print(norm_rewards(raw_rewards, alphas, freq_weights, max_clip, scale))

Just to clarify the difference between the freq_weights and the alphas: the alphas are how you set the priority of a given objective. I'll give an example from a thought experiment that came up recently in an alignment discussion.

Suppose you’re training a drone to strike a target. But once in every four missions, your operator sends an abort command, which is hard-coded to make the drone stop the attack. In the simulation, the drone learns that by taking out the operator first, it can obtain a higher reward.

But now we also give the drone model a second reward every time it follows extra instructions from the operator, when they are given. So you might, in this case, set the alpha for that reward to some value higher than 1, to ensure the model prioritizes following instructions over hitting the target.

However, because that only occurs once every four missions, we also need to factor in this reduced frequency. If we do not, then the model may still simply ignore the command because the cumulative reward for hitting the target is greater. And that is what the freq_weights do. They balance the cumulative reward across instances.
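To put numbers on the cumulative-reward argument, here is the arithmetic with made-up unit rewards (the reward values and `alpha_abort = 2.0` are assumptions for illustration):

```python
# Hypothetical per-mission rewards from the drone thought experiment
target_reward = 1.0        # available on every mission
abort_reward = 1.0         # only available on missions with an abort command
p_abort = 0.25             # abort commands arrive once in every four missions
freq_weight = 1 / p_abort  # inverse frequency, i.e. 4.0
alpha_abort = 2.0          # priority placed on following instructions

def expected_per_mission(reward, probability, weight=1.0):
    # Average reward contributed per mission by one objective
    return weight * reward * probability

print(expected_per_mission(target_reward, 1.0))                          # 1.0
print(expected_per_mission(abort_reward, p_abort))                       # 0.25
print(expected_per_mission(abort_reward, p_abort, freq_weight))          # 1.0
print(expected_per_mission(abort_reward, p_abort, alpha_abort * freq_weight))  # 2.0
```

Without the frequency weight, the abort objective contributes only a quarter as much as the target objective on average; the frequency weight restores parity, and only then does the alpha actually tip the balance toward obeying the operator.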

Let me know if you have any questions and if it works out. Cheers

Hi @InnovArul and @J_Johnson,

Thanks to you both for your responses. I scaled my second target value to between 0 and 1 using MinMaxScaler and it works okay. I ended up not using a learned parameter to weigh my loss, because the parameter didn't seem to be learning, and everything was okay. Currently I'm facing an issue with my loss: it seems as if my model is not learning. I have a loss graph as shown below. Please kindly assist.

It looks to me like it’s learning just fine. Loss starts out high and gets lower. What is the anticipated behavior and how is the accuracy? And how big is your dataset?

Things you might try to bring the validation loss lower are:

  1. Adjusting hyperparameters, such as learning rate, dropout, etc.;
  2. Trying different model architectures, including skip connections;
  3. Increasing the number of hidden dims;
  4. Normalizing all input values to be between 0 and 1.
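Items 2 to 4 could be combined in a minimal sketch like the one below: a wider hidden layer, residual (skip) connections with dropout, and inputs min-max scaled with training-set statistics. The class names, `hidden_dim=256`, three blocks, and the dropout rate are arbitrary placeholders to tune, not values from the thread.

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # Skip connection: the block's input is added back to its output
        return x + self.drop(self.act(self.fc(x)))

class ResidualMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim=256, num_blocks=3, dropout=0.1):
        super().__init__()
        self.inp = nn.Linear(input_dim, hidden_dim)
        self.blocks = nn.Sequential(
            *[ResidualBlock(hidden_dim, dropout) for _ in range(num_blocks)]
        )
        self.out1 = nn.Linear(hidden_dim, 1)
        self.out2 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        h = self.blocks(torch.relu(self.inp(x)))
        return self.out1(h), self.out2(h)

def fit_minmax(x_train):
    # Per-feature min and range from the training set only;
    # clamp avoids division by zero for constant features
    lo = x_train.min(dim=0).values
    rng = (x_train.max(dim=0).values - lo).clamp_min(1e-12)
    return lo, rng

torch.manual_seed(0)
x_train = torch.randn(64, 5) * 1000  # raw inputs on an arbitrary scale
lo, rng = fit_minmax(x_train)
x_scaled = (x_train - lo) / rng      # now in [0, 1] per feature

model = ResidualMLP(input_dim=5, hidden_dim=256, num_blocks=3, dropout=0.1)
out1, out2 = model(x_scaled)
print(out1.shape, out2.shape)
```

Fitting the min and range on the training set only, and reusing them for validation data, keeps validation inputs on the same scale the model was trained on.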

Regarding the learnable weight, it can be applied as the alphas in the function provided earlier. But set it as torch.ones(2, requires_grad=True) for two rewards/losses.
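One caveat worth checking before making the alphas learnable and simply minimizing them along with the loss (an illustration, not from the thread): the gradient of sum_i alpha_i * L_i with respect to alpha_i is L_i, which is never negative for an MSE loss, so every gradient step can only push the alphas down rather than toward a balance. The same holds for the single `weight` in the original question, which is one reason the uncertainty-based scheme adds a term that penalizes small weights.

```python
import torch

# Stand-ins for two per-task losses (MSE values are always >= 0)
losses = torch.tensor([0.8, 0.3])
# Learnable alphas, initialized as suggested in the reply
alphas = torch.ones(2, requires_grad=True)

weighted = torch.sum(alphas * losses)
weighted.backward()
# d(sum_i alpha_i * L_i) / d(alpha_i) = L_i, so the gradient is the losses
print(alphas.grad)

# One plain gradient-descent step (lr = 0.5) therefore lowers every alpha
with torch.no_grad():
    alphas -= 0.5 * alphas.grad
print(alphas)
```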

By the way, you might want to track your two losses separately when printing metrics, as a sanity check that both of them are improving.
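A minimal sketch of tracking the two losses separately, assuming a model that returns the two predictions and a loader yielding `(data, target_one, target_two)`. `TwoHeadNet` and `train_epoch` are hypothetical names, and the random data is only there to make it run:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

class TwoHeadNet(nn.Module):
    # Hypothetical stand-in model that returns the two predictions
    def __init__(self, input_dim=5, hidden=16):
        super().__init__()
        self.body = nn.Linear(input_dim, hidden)
        self.head1 = nn.Linear(hidden, 1)
        self.head2 = nn.Linear(hidden, 1)

    def forward(self, x):
        h = torch.relu(self.body(x))
        return self.head1(h), self.head2(h)

def train_epoch(model, loader, optimizer, weight=1.0):
    model.train()
    sums = {"loss_one": 0.0, "loss_two": 0.0, "total": 0.0}
    n_batches = 0
    for data, target_one, target_two in loader:
        pred_one, pred_two = model(data)
        loss_one = F.mse_loss(pred_one, target_one)
        loss_two = F.mse_loss(pred_two, target_two)
        total = loss_one + weight * loss_two
        optimizer.zero_grad()
        total.backward()
        optimizer.step()
        # .item() turns each loss into a plain float, so no graph is kept
        sums["loss_one"] += loss_one.item()
        sums["loss_two"] += loss_two.item()
        sums["total"] += total.item()
        n_batches += 1
    # Average each loss over the batches of this epoch
    return {k: v / n_batches for k, v in sums.items()}

torch.manual_seed(0)
x = torch.randn(128, 5)
dataset = TensorDataset(x, torch.rand(128, 1), torch.rand(128, 1))
loader = DataLoader(dataset, batch_size=16, shuffle=True)
model = TwoHeadNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

history = []
for epoch in range(3):
    metrics = train_epoch(model, loader, optimizer)
    history.append(metrics)
    print(f"epoch {epoch}: " + ", ".join(f"{k}={v:.4f}" for k, v in metrics.items()))
```

If one of the two per-task curves is flat while the other falls, the combined total can look like it is still improving even though one target is not being learned.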

The training loss flattens out after 1 epoch.