How to learn the weights between two losses?

I already found the Keras example for this paper, with code similar to yours, but I don't know why my eta just keeps increasing.
Here are some values per epoch:

epoch 1:
list loss: [ 211.6204, 283.3055,  276.5063] and eta: [5.0511, 5.0714, 5.0698]
epoch 2:
list loss: [210.646, 281.631, 275.2699] and eta: [5.2132, 5.2701, 5.2673]
epoch 3:
list loss: [ 211.3304, 282.8942, 276.3101] and eta: [5.3005, 5.4210, 5.4148]
epoch 4:
list loss: [ 211.3207, 282.6045, 276.2361] and eta: [5.3320, 5.5211, 5.5101]

If I'm not mistaken, the first term, torch.Tensor(loss) * torch.exp(-self.eta), equals [3.3475, 3.4172, 3.4132],
and the second term, self.eta, equals [5.3320, 5.5211, 5.5101].
The second term is already greater than the first, and if eta keeps increasing it only gets larger. So why does my eta still increase?
My code for computing the loss:

self.eta = nn.Parameter(torch.Tensor(cf['eta']))
loss_combine = torch.cuda.FloatTensor([loss_1.sum(), loss_2.sum(), loss_3.sum()]) * torch.exp(-self.eta) + self.eta
# print("loss combine: ", loss_combine)
loss_combine = loss_combine.sum()
return loss_combine

And one question about the solution "Approx. optimal weights" mentioned in Table 5 of the paper: does it use a weighted sum of the losses, i.e. do I grid-search a weight for each loss and sum them up like 1/2 * loss_1 + 1/3 * loss_2 + 1/5 * loss_3? Is that right?
And one question about the reason: why do losses that are not on the same scale, combined with a uniform total loss, make one task converge while the other two do not? In my case, if I use a simple uniform sum of the losses, after 200 epochs loss_1 is approximately 0.5, loss_2 approximately 1.2, and loss_3 greater than 7. I have tried to search for papers or other keywords but have not found anything.

Thank you

I cannot answer soon, but I think optimal weights are not used in a recent paper on natural language understanding:

Please see Algorithm 1 in this paper.

From the following calculation, I understand that the total loss is decreasing:

>>> def total_loss(loss, eta):
...     loss = torch.Tensor(loss)
...     eta = torch.Tensor(eta)
...     return (loss * torch.exp(-eta) + eta).sum()
... 
>>> total_loss([ 211.6204, 283.3055,  276.5063], [5.0511, 5.0714, 5.0698])
tensor(20.0620)
>>> total_loss([210.646, 281.631, 275.2699], [5.2132, 5.2701, 5.2673])
tensor(19.7656)
>>> total_loss([ 211.3304, 282.8942, 276.3101], [5.3005, 5.4210, 5.4148])
tensor(19.6715)
>>> total_loss([ 211.3207, 282.6045, 276.2361], [5.3320, 5.5211, 5.5101])
tensor(19.6332)

I think that the uncertainties increase in the beginning but begin to decrease after some epochs, as shown in Figure 7 of the paper. You might need to tune the learning rate.

Because sigma^2 (= exp(eta)) should be close to the loss, eta can be estimated from the initial losses as

>>> torch.log(torch.Tensor([ 211.6204, 283.3055,  276.5063]))
tensor([5.3548, 5.6465, 5.6222])

I think the maximum value eta reaches is somewhat greater than this estimate.
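
To see why eta increases at first: for a single task, the gradient of loss * exp(-eta) + eta with respect to eta is 1 - loss * exp(-eta), which is negative as long as eta < log(loss), so gradient descent keeps pushing eta up toward log(loss). A quick check with the epoch-1 values above:

import torch

loss = torch.tensor([211.6204, 283.3055, 276.5063])
eta = torch.tensor([5.0511, 5.0714, 5.0698], requires_grad=True)

total = (loss * torch.exp(-eta) + eta).sum()
total.backward()

print(eta.grad)         # all entries are negative, so eta moves upward
print(torch.log(loss))  # the stationary point: eta* = log(loss)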

Figure 2 of the paper shows that the performance depends on the weights. The total loss is given by Equation (1), where the weights sum to 1.
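
In other words, the grid-searched setting is just a fixed, normalized weighted sum, something like the sketch below (the weights here are only an illustration, not the values from the paper; the losses are taken from your 200-epoch example):

import torch

# illustrative fixed weights found by grid search; they must sum to 1
weights = torch.tensor([0.5, 0.3, 0.2])
losses = torch.tensor([0.5, 1.2, 7.0])   # e.g. loss_1, loss_2, loss_3

total_loss = (weights * losses).sum()
print(total_loss)                        # tensor(2.0100)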

Because I am not an expert in multi-task learning, you should open a new topic about multi-task learning on this site.


I will read the paper in more detail. Anyway, thank you, Tony.

You might have to use mean(), i.e. not sum().

This paper proposes a total loss composed of MSE and cross-entropy losses; other losses are outside the scope of its assumptions. Here is an implementation of Equation (10), where y1 is a continuous output and y2 is a discrete output:

import torch
import torch.nn as nn
import torch.optim as optim

class MultiTaskLoss(nn.Module):
    def __init__(self, model, loss_fn, eta):
        super(MultiTaskLoss, self).__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.eta = nn.Parameter(torch.Tensor(eta))

    def forward(self, input, targets):
        outputs = self.model(input)
        loss = [l(o,y) for l, o, y in zip(self.loss_fn, outputs, targets)]
        total_loss = torch.Tensor(loss) * torch.exp(-self.eta) + self.eta
        return loss, total_loss.sum() # omit 1/2

class MultiTaskModel(nn.Module):
    def __init__(self):
        super(MultiTaskModel, self).__init__()
        self.e  = nn.Linear(5, 5, bias=False)
        self.f1 = nn.Linear(5, 2, bias=False)
        self.f2 = nn.Linear(5, 3, bias=False)

    def forward(self, input):
        x = self.e(input)
        outputs = [self.f1(x), self.f2(x)]
        return outputs

## For the normal distribution,
loss_fn1 = nn.MSELoss()
## For the Laplace distribution,
# loss_fn1 = nn.L1Loss()
##
## Note: the original work uses the L1 loss for Instance Segmentation
## and Depth Regression, as described on page 6.
## https://arxiv.org/abs/1705.07115
##

cel = nn.CrossEntropyLoss()
def loss_fn2(x, cls):
    return 2 * cel(x, cls)

mtl = MultiTaskLoss(model=MultiTaskModel(),
                    loss_fn=[loss_fn1, loss_fn2],
                    eta=[2.0, 1.0])

print(list(mtl.parameters()))

x = torch.randn(3, 5)
y1 = torch.randn(3, 2)
y2 = torch.LongTensor([0, 2, 1])

optimizer = optim.SGD(mtl.parameters(), lr=0.1)
optimizer.zero_grad()
loss, total_loss = mtl(x, [y1, y2])
print(loss, total_loss)
total_loss.backward()
optimizer.step()
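
After training, the learned eta can be interpreted directly: in this parametrization, exp(eta) plays the role of sigma^2 and exp(-eta) is the weight multiplying each task loss. A small readout, assuming the mtl instance defined above:

with torch.no_grad():
    sigma_sq = torch.exp(mtl.eta)        # estimated sigma^2 per task
    task_weights = torch.exp(-mtl.eta)   # effective weight on each task loss
    print(sigma_sq, task_weights)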

All of my losses from the 3 tasks are of the same kind; they are all CrossEntropyLoss. So I think

loss_combine_tensor = torch.cuda.FloatTensor([loss_1.sum(), loss_2.sum(), loss_3.sum()]) 
or
loss_combine_tensor = torch.cuda.FloatTensor([loss_1.mean(), loss_2.mean(), loss_3.mean()])

give the same value, and only the grad_fn differs (SumBackward vs. MeanBackward). Can you tell me what the difference between these two functions is when the optimizer runs?

You need neither sum() nor mean() if you use CrossEntropyLoss() with default parameters. The cross-entropy loss must be multiplied by 2 according to Equation (10) in the paper. Sample code:

cel = nn.CrossEntropyLoss()
def loss_fn2(x, cls):
    return 2 * cel(x, cls)

Calling sum() or mean() on loss_1, loss_2, and loss_3 doesn't influence the optimization, because each of them is already a scalar.
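
You can check this directly: with the default reduction='mean', CrossEntropyLoss already returns a 0-dim scalar, so calling sum() or mean() on it changes nothing:

import torch
import torch.nn as nn

cel = nn.CrossEntropyLoss()            # default reduction='mean' -> 0-dim tensor
x = torch.randn(4, 3)
y = torch.LongTensor([0, 2, 1, 1])

loss = cel(x, y)
print(loss.shape)                      # torch.Size([])
print(torch.equal(loss.sum(), loss))   # True
print(torch.equal(loss.mean(), loss))  # True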

Hi Tony-Y,
Your example is great. I am a beginner in PyTorch. I am using a multi-task approach for two different tasks and want to adopt this approach. I have a ResNet-50 as the backbone and have added two branches for the two tasks. Now I want to use this MultiTaskLoss for these two tasks. Can you please briefly show how I would use your MultiTaskLoss class in my case? Below is my code:


import torch
import torch.nn as nn
import torch.nn.functional as F


class multi_output_model(torch.nn.Module):
    def __init__(self, model_core, cup_nodes, bbc_type_nodes):
        super(multi_output_model, self).__init__()

        self.resnet_model = model_core
        self.cup_nodes = cup_nodes
        self.bbc_type_nodes = bbc_type_nodes

        ''' heads '''
        self.y1o = nn.Linear(256, self.cup_nodes)
        nn.init.xavier_normal_(self.y1o.weight)
        self.y2o = nn.Linear(256, self.bbc_type_nodes)
        nn.init.xavier_normal_(self.y2o.weight)

    def forward(self, x):
        x1 = self.resnet_model(x)
        y1o = F.softmax(self.y1o(x1), dim=1)
        y2o = torch.sigmoid(self.y2o(x1))
        return y1o, y2o

Now I am calling this like:

model = multi_output_model(pretrainedImgnetModel, cup_nodes, bbc_type_nodes)
criterion = [nn.CrossEntropyLoss(), nn.CrossEntropyLoss()]  # two loss functions for the two tasks

and during training I was doing:

loss0 = criterion[0](outputs[0], torch.max(cup.float(), 1)[1])
loss1 = criterion[1](outputs[1], torch.max(bbc_type.float(), 1)[1])
totalLoss = loss0+loss1

Thanks a lot in advance.

First of all, you need to check the documentation of nn.CrossEntropyLoss. F.softmax should not be applied to y1o because the softmax is already included in CrossEntropyLoss. In addition, you should confirm whether applying sigmoid to y2o is appropriate.
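
A quick way to check that the softmax is already inside CrossEntropyLoss: it is equivalent to log_softmax followed by NLLLoss, so the heads should output raw logits:

import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 3)
target = torch.LongTensor([0, 2, 1, 1])

a = nn.CrossEntropyLoss()(logits, target)
b = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(a, b))  # True: applying F.softmax beforehand would apply softmax twice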

Hi Tony-Y,
My branch y1o is for multi-label classification and y2o is for binary classification. So I use sigmoid for y2o to get values between 0 and 1, and softmax for y1o so that the output probabilities sum to 1 for the multi-label case.
Please correct me if I am wrong.

The answer to this question is helpful. For binary classification, there are three methods.
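
For reference, the three usual ways to set up a binary classification loss in PyTorch are sketched below (the linked answer may phrase them differently):

import torch
import torch.nn as nn

logits = torch.randn(4, 1)                       # one output unit per sample
target = torch.Tensor([[1.], [0.], [1.], [0.]])

# 1) sigmoid output + BCELoss
loss_a = nn.BCELoss()(torch.sigmoid(logits), target)

# 2) raw logits + BCEWithLogitsLoss (numerically more stable)
loss_b = nn.BCEWithLogitsLoss()(logits, target)

# 3) two output units + CrossEntropyLoss
logits2 = torch.randn(4, 2)
target2 = torch.LongTensor([1, 0, 1, 0])
loss_c = nn.CrossEntropyLoss()(logits2, target2)

print(loss_a, loss_b, loss_c)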

Thanks a lot, Tony. I have read it. But for binary classification, BCELoss and CrossEntropyLoss should be the same, shouldn't they? In that case my code should be OK?

It would be nice if you could explain a little bit how I can adapt MultiTaskLoss to my case :slight_smile:

BCELoss is not the same as CrossEntropyLoss. Before considering multitask learning, you have to learn how to use loss functions in PyTorch.
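
For later readers, here is a rough sketch of how a two-head classifier like the one above could be plugged into the MultiTaskLoss class from the earlier example. The heads return raw logits, both cross-entropy losses are multiplied by 2 per Equation (10), and the backbone stand-in, node counts, and eta initialization are only placeholders:

import torch
import torch.nn as nn
import torch.optim as optim

backbone = nn.Linear(5, 256)  # dummy stand-in for a ResNet-50 trunk ending in 256 features

class TwoHeadModel(nn.Module):
    """Hypothetical variant of multi_output_model whose heads return raw logits."""
    def __init__(self, backbone, cup_nodes, bbc_type_nodes):
        super(TwoHeadModel, self).__init__()
        self.backbone = backbone
        self.y1o = nn.Linear(256, cup_nodes)
        self.y2o = nn.Linear(256, bbc_type_nodes)

    def forward(self, x):
        x = self.backbone(x)
        return [self.y1o(x), self.y2o(x)]  # no softmax/sigmoid here

cel = nn.CrossEntropyLoss()
def loss_fn(x, cls):
    return 2 * cel(x, cls)  # factor 2 for classification terms, per Equation (10)

# MultiTaskLoss is the class defined earlier in this thread
mtl = MultiTaskLoss(model=TwoHeadModel(backbone, cup_nodes=4, bbc_type_nodes=2),
                    loss_fn=[loss_fn, loss_fn],
                    eta=[1.0, 1.0])

optimizer = optim.SGD(mtl.parameters(), lr=0.01)  # mtl.parameters() includes the model weights and eta

x = torch.randn(3, 5)
cup = torch.LongTensor([0, 3, 1])
bbc_type = torch.LongTensor([1, 0, 1])

optimizer.zero_grad()
losses, total_loss = mtl(x, [cup, bbc_type])
total_loss.backward()
optimizer.step()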

Have you figured out why the losses do not change? I have the same problem: the losses stay constant while the values of sigma keep changing.

I think we should not use torch.nn.XX or torch.nn.functional.XX to get the losses in the forward function. For those who are stuck here because the losses do not change, I have reimplemented the example from the author of the paper in PyTorch: PyTorch Example.
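
Another thing worth checking when the task losses do not move: wrapping already-computed losses in a fresh tensor (e.g. torch.Tensor([...]) or torch.cuda.FloatTensor([...])) detaches them from the autograd graph, so the network weights no longer receive gradients even though eta/sigma keeps changing. A sketch of a combination step that keeps gradients flowing by using torch.stack instead:

import torch
import torch.nn as nn

class CombinedLoss(nn.Module):
    """Hypothetical combination step that preserves the autograd graph."""
    def __init__(self, n_tasks):
        super(CombinedLoss, self).__init__()
        self.eta = nn.Parameter(torch.zeros(n_tasks))

    def forward(self, losses):
        losses = torch.stack(losses)  # keeps grad_fn, unlike torch.Tensor(losses)
        return (losses * torch.exp(-self.eta) + self.eta).sum()

# quick check that gradients reach both eta and the upstream parameters
w = nn.Parameter(torch.tensor([2.0, 0.0, 1.0]))
task_losses = [w[0] ** 2, (w[1] - 2) ** 2, (w[2] + 1) ** 2]

combined = CombinedLoss(n_tasks=3)
total = combined(task_losses)
total.backward()

print(w.grad)             # non-zero: the upstream parameters still get gradients
print(combined.eta.grad)  # non-zero: eta is trained as well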



I wrote an example code and it seems to be working.
The key might be to make the optimizer recognize the learnable parameters (the multi-task loss's sigmas).


Hi,
You mentioned the usage as:

# usage
is_regression = torch.Tensor([True, True, False])  # True: Regression/MeanSquaredErrorLoss, False: Classification/CrossEntropyLoss

multitaskloss_instance = MultiTaskLoss(is_regression)

So in the case of a classification problem I should put
is_regression = False
Can you clarify this a bit?

Thank you for your comment.

Yes, that’s right.
If you have loss1, loss2, and loss3, which are cross-entropy loss, cross-entropy loss, and MSE loss respectively, you should pass is_regression = torch.Tensor([False, False, True]) to the constructor.

I'd like to hear whether this multi-task loss implementation works in your setting, too.

Can I ask where this equation is from? I cannot find it in the paper.