I’m trying to implement learnable loss weights for a multi-task setup (mostly inspired by: How to learn the weights between two losses?).
The following snippet is a minimal example of a multi-task model and a multi-task loss with learnable loss weights.
The issue is that the following implementation only updates the loss weights but not the model weights. How can I optimize both the loss weights and the model weights?
from typing import Dict

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(10, 3)
        self.regressor = nn.Linear(10, 1)

    def forward(self, x):
        classifier_op = self.classifier(x)
        regressor_op = self.regressor(x)
        return {
            "classification": classifier_op,
            "regression": regressor_op
        }


class MultiTaskLoss(nn.Module):
    def __init__(self,
                 loss_mapping: Dict[str, nn.Module] = {"classification": nn.CrossEntropyLoss(),
                                                       "regression": nn.MSELoss()}):
        super().__init__()
        self.loss_mapping = loss_mapping
        self.loss_weights = nn.Parameter(torch.ones(2))

    def forward(self,
                model_output: Dict[str, torch.Tensor],
                targets: Dict[str, torch.Tensor]):
        losses = []
        for task_name, task_op in model_output.items():
            loss_fn = self.loss_mapping[task_name]
            loss_value = loss_fn(task_op, targets[task_name])
            losses.append(loss_value)
        losses = torch.Tensor(losses)
        # loss_weighting
        losses = losses * self.loss_weights
        return losses.sum()


model = Model()
mlt_loss = MultiTaskLoss()
params = list(model.parameters()) + list(mlt_loss.parameters())
optimizer = torch.optim.Adam(params, lr=0.1)

# train
for i in range(3):
    ip = torch.rand((1, 10))
    targets = {
        "classification": torch.Tensor([[1, 0, 0]]),
        "regression": torch.rand([1])
    }
    optimizer.zero_grad()
    model_op = model(ip)
    loss = mlt_loss(model_op, targets)
    loss.backward()
    optimizer.step()

    # printing the parameters of the optimizer
    print(f"\nEpoch {i}")
    for t in optimizer.param_groups[0]['params']:
        print(t.shape, t.mean().item())
The following confirms that all the parameters passed to the optimizer are learnable (requires_grad=True):

for t in optimizer.param_groups[0]['params']:
    print(t.shape, t.requires_grad)

torch.Size([3, 10]) True  < classifier weight
torch.Size([3]) True      < classifier bias
torch.Size([1, 10]) True  < regressor weight
torch.Size([1]) True      < regressor bias
torch.Size([2]) True      < loss weights
The training loop above prints the mean of each parameter after every epoch. The output shows that only the loss weights (the last tensor) change from epoch to epoch, while the model parameters stay exactly the same:
Epoch 0
torch.Size([3, 10]) 0.07859472185373306
torch.Size([3]) 0.000278279185295105
torch.Size([1, 10]) -0.04856250435113907
torch.Size([1]) -0.26647862792015076
torch.Size([2]) 0.8999999761581421
Epoch 1
torch.Size([3, 10]) 0.07859472185373306 < not updated
torch.Size([3]) 0.000278279185295105 < not updated
torch.Size([1, 10]) -0.04856250435113907 < not updated
torch.Size([1]) -0.26647862792015076 < not updated
torch.Size([2]) 0.8023112416267395 < updated
Epoch 2
torch.Size([3, 10]) 0.07859472185373306 < not updated
torch.Size([3]) 0.000278279185295105 < not updated
torch.Size([1, 10]) -0.04856250435113907 < not updated
torch.Size([1]) -0.26647862792015076 < not updated
torch.Size([2]) 0.7098063230514526 < updated
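
For what it's worth, adding a quick gradient check right after loss.backward() (my own diagnostic, not part of the snippet above) seems to confirm that no gradient ever reaches the model parameters:

# diagnostic placed right after loss.backward()
for name, p in model.named_parameters():
    print(name, p.grad)  # prints None for every model parameter
print("loss_weights", mlt_loss.loss_weights.grad)  # this gradient is populated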
How can I get both the model parameters and the loss weights to update in this setting?
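
My current suspicion is that losses = torch.Tensor(losses) builds a brand-new tensor from the Python list and thereby detaches the individual task losses from the computation graph, so only loss_weights stays connected to the final loss. Would replacing that line with torch.stack, as sketched below (untested beyond this minimal example), be the right fix, or is something else needed?

# candidate change inside MultiTaskLoss.forward
losses = torch.stack(losses)         # keeps each task loss attached to the graph
losses = losses * self.loss_weights  # weighting as before
return losses.sum()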