I’d like to ask this venerable community for a suggestion on how to handle the loss for a multi-task network.
The data flows forward in parallel through a U-Net-like network with two heads, each performing semantic segmentation with a different target (this is not a multi-class configuration; the tasks differ).
I was unable to find hints here or on GitHub.
Please suggest examples, or point me to where I can find them.
Should I use a separate loss for each head and sum them?
If so, how exactly is that achieved? Is it possible without writing a custom loss function?
Both heads use torch.nn.CrossEntropyLoss(); can I simply add the two losses, or should I write a custom implementation of the loss?
Thank you in advance, and I apologize if I have violated any rules with my first question (I have the impression the topics are not well categorized here).
When you now optimize, you will train head-1 and head-2 (in a sense)
separately to perform well on task-1 and task-2, respectively. The shared
part of the network, upstream of the two heads, will be trained jointly to
perform reasonably well on both tasks, making trade-offs in its trained
weights that might, for example, make the task-1 performance a little
worse if doing so yields a larger task-2 performance gain.
Note the weight_task_2 parameter I included in my code fragment. It
allows you to tune your model’s relative performance on the two tasks.
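To make the approach concrete, here is a minimal sketch of summing two CrossEntropyLoss terms with a weighting factor. The model architecture, head names, and default value of weight_task_2 are assumptions for illustration; the point is that the two losses can be added directly, with no custom loss class needed, and autograd backpropagates through both heads into the shared backbone.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a two-head U-Net-like model; the real shared
# body would be a full encoder-decoder, but a single conv suffices here.
class TwoHeadNet(nn.Module):
    def __init__(self, in_channels=3, classes_1=2, classes_2=4):
        super().__init__()
        self.backbone = nn.Conv2d(in_channels, 16, 3, padding=1)  # shared part
        self.head_1 = nn.Conv2d(16, classes_1, 1)  # task-1 segmentation head
        self.head_2 = nn.Conv2d(16, classes_2, 1)  # task-2 segmentation head

    def forward(self, x):
        features = torch.relu(self.backbone(x))
        return self.head_1(features), self.head_2(features)

model = TwoHeadNet()
criterion_1 = nn.CrossEntropyLoss()
criterion_2 = nn.CrossEntropyLoss()
weight_task_2 = 1.0  # tune to trade off task-1 vs. task-2 performance

# Dummy batch: 2 images, per-pixel class-index targets for each task.
x = torch.randn(2, 3, 32, 32)
target_1 = torch.randint(0, 2, (2, 32, 32))
target_2 = torch.randint(0, 4, (2, 32, 32))

logits_1, logits_2 = model(x)
loss = criterion_1(logits_1, target_1) + weight_task_2 * criterion_2(logits_2, target_2)
loss.backward()  # gradients flow into both heads and the shared backbone
```

After loss.backward(), an optimizer step over model.parameters() updates everything at once; raising weight_task_2 above 1.0 biases the shared weights toward task 2, lowering it biases them toward task 1.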