Implementing Multi-Task Learning For Classification and Segmentation

Hello everyone. I have some questions that I think are more related to deep learning in general than to PyTorch. I think they are pretty simple; any opinion is welcome.

I am dealing with a problem where I use a segmentation network, and based on the segmentation I apply some simple rules to make a binary classification. What I am trying to do is train a classification network that does the binary classification directly.
The advantages would be that the classification network is smaller and faster, it is easier to grow the dataset, and the post-processing is no longer necessary.
The disadvantage is that classification networks usually need more data to generalize well.

At the moment my segmentation dataset has 3000 images and the classification dataset has 33000 images. Only a small intersection of those images has both types of ground truth.

The issue I am facing is that the segmentation model generalizes better to new data (data that isn't in any of my training, validation, or test datasets), while the classification model struggles with new data.

What I want to do is try multi-task learning, hoping that the segmentation task will help the classification task to “understand the images correctly” and generalize better.

How exactly can this training be done? I can think of a few ways, but I don’t know the consequences of each approach. If someone could give some hints or help me understand the nuances of each approach, I would be very grateful.

Here are the approaches I can think of:

  1. If my entire dataset had ground truth for both tasks, I would simply train this way:
    [image: mtl_training]
  2. Since only a small intersection of my dataset has ground truth for both tasks, I will have to switch between tasks, like this:
    [image: mtl_training2]
    [image: mtl_training3]
    But how exactly do I do it? Train one entire epoch on the segmentation dataset, then one epoch on the classification dataset? Should I keep alternating between one batch of classification and one batch of segmentation? Or should I feed one batch of classification, keep the loss, then feed a different batch of segmentation (different images), sum the losses, and then backpropagate? (A sketch of this last variant is shown below.)
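To make that last variant concrete, here is a minimal sketch of a single training step. The two-headed model interface, the loss functions, and the `alpha` weight are assumptions for illustration, not a prescription:

```python
# Minimal sketch: the model is assumed to share an encoder and return
# (seg_logits, cls_logits), so both losses flow through the shared
# weights in a single backward pass. All names are hypothetical.
import torch
import torch.nn.functional as F

def train_step(model, optimizer, seg_batch, cls_batch, alpha=0.5):
    seg_images, seg_masks = seg_batch      # segmentation ground truth (masks as class indices)
    cls_images, cls_labels = cls_batch     # classification ground truth (labels as floats)

    optimizer.zero_grad()

    # Forward the segmentation batch and keep only its loss.
    seg_logits, _ = model(seg_images)
    seg_loss = F.cross_entropy(seg_logits, seg_masks)

    # Forward a different classification batch.
    _, cls_logits = model(cls_images)
    cls_loss = F.binary_cross_entropy_with_logits(cls_logits, cls_labels)

    # Weighted sum of the two losses, then one backward pass.
    loss = alpha * seg_loss + (1 - alpha) * cls_loss
    loss.backward()
    optimizer.step()
    return seg_loss.item(), cls_loss.item()
```

Note that this variant keeps both computation graphs alive until the single `backward()` call, which is one reason it can need more GPU memory than alternating batches.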

I think some of these approaches will require more GPU memory than others, which would force me to use smaller batch sizes.

Thank you so much to everyone who spent the time to read this huge topic.
I will try some of these approaches and share my findings here as well; any help is welcome.


Multi-task learning using classification and segmentation together looks like an amazing approach.

I’m curious about this subject. I’m replying here to see if it helps to get the attention of someone who can explain more about it.

Have you already tried training with the 3000 images that have both classification and segmentation ground truth? If so, how did it go?

Hi, interesting problem. Did you find a solution?

Hello @sidb23.
I had forgotten about this topic.
The implementation I did two years ago was more or less the last one I mentioned.
I just “summed” both datasets and let the sampler form batches from the pooled images.
For the images in the batch that have classification ground truth, I compute the classification loss, and the same for segmentation. Then I take a weighted average of the two losses, using an alpha weight to give more attention to one of the tasks.
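A rough sketch of that idea (the names, the boolean flags, and the two-headed model interface are illustrative assumptions, not my exact code):

```python
# Mixed-batch loss sketch: each sample carries flags saying which ground
# truth it has, and each per-task loss is computed only over that subset.
import torch
import torch.nn.functional as F

def mixed_batch_loss(model, images, seg_masks, cls_labels,
                     has_seg, has_cls, alpha=0.5):
    # has_seg / has_cls are boolean tensors of shape (batch,) marking which
    # samples in the batch have segmentation / classification ground truth.
    seg_logits, cls_logits = model(images)

    loss = images.new_zeros(())
    if has_seg.any():
        seg_loss = F.cross_entropy(seg_logits[has_seg], seg_masks[has_seg])
        loss = loss + alpha * seg_loss
    if has_cls.any():
        cls_loss = F.binary_cross_entropy_with_logits(
            cls_logits[has_cls], cls_labels[has_cls])
        loss = loss + (1 - alpha) * cls_loss
    return loss
```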
In my case, the classification dataset is much bigger, which means most of my batches contain only classification samples. This can be a problem, because the model can end up ignoring the segmentation task.
I also tried to train only the segmentation part for some epochs, and then introduce the other task.
Nowadays I think the best way to go is to make sure your batches always contain samples for all tasks. This forces the network to learn all of them.
Try making something like this:
Classification dataset has 900 images.
Segmentation dataset has 100 images.
Your batch size is 10.
Make sure every batch has 9 classification images and 1 segmentation image.
Then play with the weight you give to each loss when training.
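One simple way to get that fixed per-batch composition is one DataLoader per dataset with matching per-task batch sizes. This is just a sketch under the assumption that `cls_dataset` and `seg_dataset` are your two Dataset objects:

```python
# Hypothetical sketch: every combined batch has 9 classification samples
# and 1 segmentation sample.
import itertools
from torch.utils.data import DataLoader

cls_loader = DataLoader(cls_dataset, batch_size=9, shuffle=True)
seg_loader = DataLoader(seg_dataset, batch_size=1, shuffle=True)

# The segmentation loader is much shorter, so cycle it while iterating
# the classification loader once per epoch.
for (cls_images, cls_labels), (seg_images, seg_masks) in zip(
        cls_loader, itertools.cycle(seg_loader)):
    # ... forward each sub-batch (or concatenate the images), compute the
    # two losses, combine them with the chosen weights, and backpropagate.
    pass
```

(Keep in mind `itertools.cycle` replays the first epoch's segmentation order; re-creating the segmentation iterator each epoch avoids that if it matters.)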

In the end, my classification model became much more robust and generalized much better to new data.
Hope this helps you.
