Training a multi-task model with a learnable loss component weight

I’m trying to implement a learnable loss weight for a multi-task setup. (mostly inspired from: How to learn the weights between two losses?)

The following snippet is a minimal example of a multi-task model and a mult-task loss with a learnable loss weight.

The issue is the following implementation only updates the loss weight but not the model weights. How can I optimize for both the loss weight and model weights?

from typing import Dict
import torch
import torch.nn as nn

class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(10, 3)
        self.regressor = nn.Linear(10, 1)
        
    def forward(self, x):
        classifier_op = self.classifier(x)
        regressor_op = self.regressor(x)

        return {
            "classification": classifier_op,
            "regression": regressor_op
        }

class MultiTaskLoss(nn.Module):

    def __init__(self, 
                 loss_mapping: Dict[str, nn.Module] = {"classification": nn.CrossEntropyLoss(), "regression": nn.MSELoss()}):
        super().__init__()
        self.loss_mapping = loss_mapping
        self.loss_weights = nn.Parameter(torch.ones(2))

    def forward(self,
                model_output: Dict[str, torch.Tensor], 
                targets: Dict[str, torch.Tensor]):
        
        losses = []
        
        for task_name, task_op in model_output.items():
            loss_fn = self.loss_mapping[task_name]
            loss_value = loss_fn(task_op, targets[task_name])
            losses.append(loss_value)

        losses = torch.Tensor(losses)

        # loss_weighting
        losses = losses * self.loss_weights

        return losses.sum()

model = Model()
mlt_loss = MultiTaskLoss()

params = list(model.parameters()) + list(mlt_loss.parameters())
optimizer = torch.optim.Adam(params, lr=0.1)

# train
for i in range(3):

    ip = torch.rand((1, 10))

    targets = {
        "classification": torch.Tensor([[1, 0, 0]]),
        "regression": torch.rand([1])
    }

    optimizer.zero_grad()
    model_op = model(ip)
    loss = mlt_loss(model_op, targets)
    loss.backward()
    optimizer.step()
    
    # printing the parameters of the optimizer
    print(f"\nEpoch {i}")
    for t in optimizer.param_groups[0]['params']:
        print(t.shape, t.mean().item())

The following shows all the learnable parameters applicable to the optimizer:

for t in optimizer.param_groups[0]['params']:
    print(t.shape, t.requires_grad)

torch.Size([3, 10]) True             < classifer weight
torch.Size([3]) True                 < classifier bias
torch.Size([1, 10]) True             < regressor weight
torch.Size([1]) True                 < regressor bias
torch.Size([2]) True                 < loss weight

The above snippet prints the mean of the each weight after each epoch (which shows only the loss weight has been updated but not the rest of the parameters - only the last weight is updated in the output after each epoch):


Epoch 0
torch.Size([3, 10]) 0.07859472185373306
torch.Size([3]) 0.000278279185295105
torch.Size([1, 10]) -0.04856250435113907
torch.Size([1]) -0.26647862792015076
torch.Size([2]) 0.8999999761581421               

Epoch 1
torch.Size([3, 10]) 0.07859472185373306          < not updated
torch.Size([3]) 0.000278279185295105             < not updated
torch.Size([1, 10]) -0.04856250435113907         < not updated
torch.Size([1]) -0.26647862792015076             < not updated
torch.Size([2]) 0.8023112416267395               < updated

Epoch 2
torch.Size([3, 10]) 0.07859472185373306          < not updated
torch.Size([3]) 0.000278279185295105             < not updated
torch.Size([1, 10]) -0.04856250435113907         < not updated
torch.Size([1]) -0.26647862792015076             < not updated
torch.Size([2]) 0.7098063230514526               < updated

How can I make both the model and loss weight parameters updated in this setting?

Hi akt42!

This line breaks the computation graph that links losses back to your
model parameters. So when you backpropagate, the .grad properties
of your model parameters remain None and your optimizer doesn’t update
the parameters.

Change the line to losses = torch.cat (losses) (or something similar,
depending on what shape you need losses to have).

(torch.Tensor() constructs a new tensor that has requires_grad = False.
Furthermore, losses starts out as a list of two single-value tensors (that do
happen to have requires_grad = True). But they get converted to python
scalars along the way and those have no notion of requires_grad.)

Best.

K. Frank

1 Like

Hi Frank,

Thank you! I was able to get it working as follows based on your observation.

class MultiTaskLoss(nn.Module):

    def __init__(self, 
                 loss_mapping: Dict[str, nn.Module] = {"classification": nn.CrossEntropyLoss(), "regression": nn.MSELoss()}):
        super().__init__()
        self.loss_mapping = loss_mapping
        self.loss_weights = nn.Parameter(torch.ones(2))

    def forward(self,
                model_output: Dict[str, torch.Tensor], 
                targets: Dict[str, torch.Tensor]):
        
        losses = []
        
        for task_name, task_op in model_output.items():
            loss_fn = self.loss_mapping[task_name]
            loss_value = loss_fn(task_op, targets[task_name])
            losses.append(loss_value.unsqueeze(dim=0))

        losses = torch.cat(losses)

        # loss_weighting
        losses = losses * self.loss_weights

        return losses.sum()