Transfer weights from a classifier's encoder to an image segmentation model


For a task, I implemented a UNet segmentation model from scratch that successfully segments five features (e.g., A, B, C, D, E). Using the segmentation_models library, I also compared different backbones with ImageNet weights.

Now I have found a large (150k images), public Kaggle dataset from the same domain. It does not label the above classes at pixel level; instead, it assigns each image to one of several classes (i.e., it is meant for classification tasks). So instead of ImageNet weights, I would also like to try transferring weights learned on this dataset.

Could I somehow save the encoder, or the encoder's weights (e.g., a ResNet), from a simple classifier and load them into my UNet to test whether this transfer learning approach gives better results? Code examples are very welcome.

Kind regards

It depends on your UNet implementation, but you could reuse its encoder in a separate classification model (the reverse of what you are probably already doing when you plug an ImageNet-pretrained classification backbone into your UNet).
To do so, take the encoder from the UNet, add a classification head to it (e.g., global average pooling followed by a linear layer), and train this classifier on the large classification dataset. Afterwards, save the encoder's state_dict and load only that into your UNet's encoder; the decoder is then trained on your segmentation data as before.
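A minimal PyTorch sketch of this workflow follows. It assumes both models share the same `Encoder` class, so the encoder's state_dict keys match exactly. `Encoder`, `UNet`, `Classifier`, the number of Kaggle classes (10) and the file name `encoder_pretrained.pt` are all hypothetical placeholders; substitute your own UNet and dataset, and add a real training loop where indicated.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Hypothetical small encoder shared by the classifier and the UNet."""

    def __init__(self, in_ch=3, base=16):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU(inplace=True))
        self.block2 = nn.Sequential(
            nn.MaxPool2d(2), nn.Conv2d(base, base * 2, 3, padding=1), nn.ReLU(inplace=True)
        )

    def forward(self, x):
        f1 = self.block1(x)   # full-resolution features (skip connection)
        f2 = self.block2(f1)  # half-resolution features (bottleneck)
        return f1, f2


class UNet(nn.Module):
    """Hypothetical tiny UNet: shared encoder plus one decoder stage."""

    def __init__(self, num_classes=5, base=16):
        super().__init__()
        self.encoder = Encoder(base=base)
        self.up = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.decoder = nn.Sequential(nn.Conv2d(base * 2, base, 3, padding=1), nn.ReLU(inplace=True))
        self.head = nn.Conv2d(base, num_classes, 1)  # per-pixel class logits

    def forward(self, x):
        f1, f2 = self.encoder(x)
        x = self.up(f2)
        x = self.decoder(torch.cat([x, f1], dim=1))
        return self.head(x)


class Classifier(nn.Module):
    """Same encoder plus a global-average-pooling classification head."""

    def __init__(self, num_classes, base=16):
        super().__init__()
        self.encoder = Encoder(base=base)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(base * 2, num_classes)

    def forward(self, x):
        _, f2 = self.encoder(x)
        return self.fc(torch.flatten(self.pool(f2), 1))


# 1) Pretrain the classifier on the large classification dataset.
clf = Classifier(num_classes=10)  # 10 = hypothetical number of Kaggle classes
# ... your training loop for clf goes here ...

# 2) Save only the encoder's weights (the classification head is discarded).
torch.save(clf.encoder.state_dict(), "encoder_pretrained.pt")

# 3) Load them into the UNet encoder; the decoder keeps its random initialization.
unet = UNet(num_classes=5)
unet.encoder.load_state_dict(torch.load("encoder_pretrained.pt"))
```

Because `load_state_dict` is strict by default, it raises an error if any key or shape differs, which is a useful check that the two encoders really are identical. If your UNet comes from a library such as segmentation_models_pytorch, the backbone is usually exposed as an `encoder` submodule, so the same idea applies; compare the key names of both state_dicts before loading.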