Combining multiple pretrained models for UNet


I have a UNet that does binary semantic segmentation and I want to extend it to detect multiple classes. My first idea was to train a separate model for each class and then, at prediction time, combine all of the pretrained models into a single combined model. Any idea on how to approach this?

Also, if this is not a good idea and if you have a better approach please let me know :smiley:

You could try to use multiple models, where each model predicts a single class vs. the others, but you would then need to decide how to combine the results. If you are working on a multi-label classification (zero, one, or multiple classes can be active for each sample), you could use the outputs directly. On the other hand, if you are working on a multi-class classification (exactly one class per sample), you would have to decide how to merge the outputs, e.g. whether only the highest prediction should be used as the class prediction.
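A minimal sketch of that merging step, assuming each per-class model outputs sigmoid probabilities (the models are stubbed here as random probability maps; names and shapes are illustrative, not from the original post):

```python
import torch

# Pretend outputs of nb_classes separate binary models, each producing
# sigmoid probabilities of shape [batch, 1, H, W] (one class vs. the rest).
torch.manual_seed(0)
nb_classes, batch, H, W = 3, 2, 4, 4
per_class_probs = [torch.rand(batch, 1, H, W) for _ in range(nb_classes)]

# Stack the per-class score maps into [batch, nb_classes, H, W]
stacked = torch.cat(per_class_probs, dim=1)

# Multi-label interpretation: threshold each class independently,
# so zero, one, or several classes can be active per pixel.
multi_label_mask = stacked > 0.5

# Multi-class interpretation: keep only the highest score per pixel,
# yielding exactly one class index per pixel.
multi_class_pred = stacked.argmax(dim=1)  # [batch, H, W]
```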

Usually you would just use a single model and let it predict multiple classes directly.
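For segmentation, "predict multiple classes directly" usually just means giving the final layer nb_classes output channels. A minimal sketch (not a full UNet; the decoder output is faked with random features):

```python
import torch
import torch.nn as nn

# The only architectural change for multi-class segmentation: the final
# 1x1 conv maps decoder features to nb_classes logit channels.
nb_classes, in_channels = 4, 16
final_conv = nn.Conv2d(in_channels, nb_classes, kernel_size=1)

features = torch.randn(2, in_channels, 8, 8)  # stand-in for UNet decoder output
logits = final_conv(features)                 # [2, nb_classes, 8, 8]

# nn.CrossEntropyLoss expects raw logits and class-index targets of
# shape [N, H, W] for segmentation.
criterion = nn.CrossEntropyLoss()
target = torch.randint(0, nb_classes, (2, 8, 8))
loss = criterion(logits, target)
```

This trains a single network end-to-end instead of maintaining one binary model per class.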

Thank you for your answer! I understood the part about multi-label classification, but I am more interested in multi-class classification and didn’t understand that part. Do I have to train multiple models, each predicting one class vs. the others, and then choose the best predictor for every class?

This would be your proposed approach, if I understand it correctly, and I just tried to clarify the workflow.
For a “standard” multi-class classification, you would just use an output layer which returns nb_classes logits (e.g. via nn.Linear(in_features, nb_classes)).
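A short sketch of that standard setup for classification, assuming some feature extractor in front of the output layer (the features are faked with random tensors here):

```python
import torch
import torch.nn as nn

# Standard multi-class head: a single linear layer returning nb_classes
# logits. No softmax in the model, since nn.CrossEntropyLoss applies
# log-softmax internally.
in_features, nb_classes = 32, 5
output_layer = nn.Linear(in_features, nb_classes)

x = torch.randn(8, in_features)   # stand-in for extracted features
logits = output_layer(x)          # [8, nb_classes]
pred = logits.argmax(dim=1)       # predicted class index per sample

criterion = nn.CrossEntropyLoss()
target = torch.randint(0, nb_classes, (8,))
loss = criterion(logits, target)
```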
