Multiple head loss implementation

Greetings!

I’d like to ask this community for a suggestion on how to handle the loss for a multi-task network.
The data flows forward in parallel through a U-Net-like network with two heads, each performing semantic segmentation against different targets (not a multi-class configuration; the tasks differ).
I was unable to find hints here or on GitHub.
Please suggest an approach or point me to examples.
Should I use a separate loss for each head and sum them?
I wonder how exactly this is achieved; is it possible without writing a custom loss function?
Both heads use torch.nn.CrossEntropyLoss(); can I add the two losses explicitly, or should I write a custom loss implementation?
Thank you in advance, and I apologize if I have violated any rules with my first question (I have the impression that the topics here are not categorized very clearly).


Hi Iaroslav!

Yes, it is perfectly reasonable to add the two losses together, and
there is no need to write a custom loss function.

So, something like:

criterion = torch.nn.CrossEntropyLoss()
loss_task_1 = criterion(pred_head_1, target_task_1)
loss_task_2 = criterion(pred_head_2, target_task_2)
total_loss = loss_task_1 + weight_task_2 * loss_task_2  # weight_task_2 is a scalar you choose
total_loss.backward()

When you now optimize, you will train head-1 and head-2 (in a sense)
separately to perform well on task-1 and task-2, respectively. The shared
part of the network, upstream of the two heads, will be trained jointly to
perform reasonably well on both tasks, making trade-offs in its trained
weights that might, for example, make the task-1 performance a little
worse if doing so yields a larger task-2 performance gain.

Note the weight_task_2 parameter I included in my code fragment. It
allows you to tune your model’s relative performance on the two tasks.
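
For concreteness, here is a minimal, self-contained sketch of the whole idea. The TwoHeadNet class, channel counts, class counts, and dummy data below are made up for illustration and are not your actual U-Net; the point is just where the two losses and their weighted sum fit into a training step:

import torch
import torch.nn as nn

class TwoHeadNet(nn.Module):
    # stand-in for a U-Net-like network: a shared trunk plus two task-specific heads
    def __init__(self, in_channels=3, n_classes_1=4, n_classes_2=2):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
        )
        self.head_1 = nn.Conv2d(16, n_classes_1, 1)  # per-pixel logits for task 1
        self.head_2 = nn.Conv2d(16, n_classes_2, 1)  # per-pixel logits for task 2

    def forward(self, x):
        features = self.trunk(x)  # shared features, used by both heads
        return self.head_1(features), self.head_2(features)

model = TwoHeadNet()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
weight_task_2 = 1.0  # relative weight of task 2, a hyperparameter

# dummy batch: two 32x32 images with per-pixel integer class targets for each task
images = torch.randn(2, 3, 32, 32)
target_task_1 = torch.randint(0, 4, (2, 32, 32))
target_task_2 = torch.randint(0, 2, (2, 32, 32))

optimizer.zero_grad()
pred_head_1, pred_head_2 = model(images)
loss_task_1 = criterion(pred_head_1, target_task_1)
loss_task_2 = criterion(pred_head_2, target_task_2)
total_loss = loss_task_1 + weight_task_2 * loss_task_2
total_loss.backward()  # gradients from both losses accumulate in the shared trunk
optimizer.step()

The two per-head losses stay completely independent; summing them into a single scalar is all autograd needs to push gradients from both tasks back through the shared trunk.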

Best.

K. Frank


Thank you, K. Frank!