How to learn the weights between two losses?

In order to avoid numerical instability when optimizing sigma directly (it must stay positive and appears in a denominator), we can use a change of variables:

eta = log(sigma^2)

The new variable eta can take any value in (-oo, +oo), so it can be optimized without constraints, and exp(-eta) = 1/sigma^2 stays positive.
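
For reference, this matches the uncertainty-based weighting of Kendall et al. (2018), where the combined objective is roughly

total = sum_i [ 1/(2*sigma_i^2) * L_i + log(sigma_i) ]

Substituting eta_i = log(sigma_i^2) gives

total = 1/2 * sum_i [ exp(-eta_i) * L_i + eta_i ]

which is the form used in the sample code below, with the constant factor 1/2 dropped.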

Sample code:

import torch
import torch.nn as nn
import torch.optim as optim

class MultiTaskLoss(nn.Module):
    def __init__(self, model, loss_fn, eta):
        super(MultiTaskLoss, self).__init__()
        self.model = model
        self.loss_fn = loss_fn
        # eta = log(sigma^2) for each task, learned jointly with the model
        self.eta = nn.Parameter(torch.tensor(eta))

    def forward(self, input, targets):
        outputs = self.model(input)
        # per-task losses
        loss = [l(o, y).sum() for l, o, y in zip(self.loss_fn, outputs, targets)]
        # stack the losses so the autograd graph is preserved, weight each one
        # by exp(-eta) and add the eta regularization term
        total_loss = torch.stack(loss) * torch.exp(-self.eta) + self.eta
        return loss, total_loss.sum()  # the constant 1/2 factor is omitted

class MultiTaskModel(nn.Module):
    def __init__(self):
        super(MultiTaskModel, self).__init__()
        self.f1 = nn.Linear(5, 1, bias=False)
        self.f2 = nn.Linear(5, 1, bias=False)

    def forward(self, input):
        # two task-specific heads applied to the same input
        outputs = [self.f1(input).squeeze(), self.f2(input).squeeze()]
        return outputs

mtl = MultiTaskLoss(model=MultiTaskModel(),
                    loss_fn=[nn.MSELoss(), nn.MSELoss()],
                    eta=[2.0, 1.0])

print(list(mtl.parameters()))

x = torch.randn(3, 5)
y1 = torch.randn(3)
y2 = torch.randn(3)

optimizer = optim.SGD(mtl.parameters(), lr=0.1)
optimizer.zero_grad()
loss, total_loss = mtl(x, [y1, y2])
print(loss, total_loss)
total_loss.backward()
optimizer.step()

Output:

[Parameter containing:
tensor([2., 1.], requires_grad=True), Parameter containing:
tensor([[-0.0387,  0.3287,  0.2549,  0.3336,  0.0195]], requires_grad=True), Parameter containing:
tensor([[0.2908, 0.2801, 0.1108, 0.4235, 0.0308]], requires_grad=True)]
[tensor(3.3697, grad_fn=<SumBackward0>), tensor(2.1123, grad_fn=<SumBackward0>)] tensor(4.2331, grad_fn=<SumBackward0>)
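
After training, the learned eta can be mapped back to interpretable quantities (a minimal sketch, assuming eta = log(sigma^2) as above):

sigma = torch.exp(mtl.eta / 2)    # learned noise scales sigma_i
weights = torch.exp(-mtl.eta)     # effective per-task loss weights 1/sigma_i^2
print(sigma.data, weights.data)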