Transfer learning from one network architecture to another network architecture

Hi. I have an EfficientNet model for image classification. It has a stack of feature layers, one average pooling layer, and a few fully connected layers at the end; assume there are 5 classes. I trained the model on 1000 3D MRI images and saved its state_dict. I also have a second image classification model with 10 classes that I want to train on 200 3D images. I want the second model to reuse the first model's weights for every layer before the fully connected layers. How can I do that?
This is what I know so far:
1. Create an instance of the first classification model and load its saved weights:
```python
Classification1 = Classification1_Model()
Classification1.load_state_dict(torch.load(path_to_state_dict))
```
2. Create an instance of the second classification model.

But here is my exact question: how can I use the pretrained weights in the new classifier? Should I freeze the layers before the average pooling layer, and how do I put the weights into the new model?

I am assuming that the network architecture of both models is exactly the same except for the classification layer. In that case, load the saved weights into the first model with the load_state_dict function, then replace its classification layer with Classification.linear = nn.Linear(input_feature, 200) and fine-tune starting from the pretrained weights.

How should I fill in input_feature, or is there no need to? And instead of 200 you mean 10, right? The second model has 10 classes.
I also have another important question: if I had a segmentation model and wanted to use its weights for a classification model, what should I do? How can I use the weights in that situation?

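Yes, basically replace the output size of the last linear layer with your new number of classes (10). input_feature must equal the number of features entering that layer; you can read it from the existing layer's in_features attribute instead of hard-coding it.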
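A minimal sketch of the whole transfer step, assuming a hypothetical `Classification1_Model` whose backbone is stored in `features`, followed by `avgpool` and a final `nn.Linear` named `linear` (a real EfficientNet may name these differently; torchvision's 2D EfficientNet, for example, uses `classifier`). The model is first built with its original 5-class head so the saved state_dict matches key for key; only then is the head swapped for a 10-class one. Freezing the backbone is optional: with only 200 images it is a reasonable first thing to try, but treat it as an assumption to validate, not a rule.

```python
import torch
import torch.nn as nn


class Classification1_Model(nn.Module):
    """Hypothetical stand-in for the 3D EfficientNet: backbone, pooling, linear head."""

    def __init__(self, num_classes=5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv3d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.linear = nn.Linear(16, num_classes)

    def forward(self, x):
        x = self.avgpool(self.features(x))
        return self.linear(torch.flatten(x, 1))


def build_10_class_model(state_dict, freeze_backbone=True):
    # Build with the ORIGINAL 5-class head so every key and shape matches the checkpoint.
    model = Classification1_Model(num_classes=5)
    model.load_state_dict(state_dict)
    # input_feature: read it from the old head instead of hard-coding it.
    in_features = model.linear.in_features
    model.linear = nn.Linear(in_features, 10)  # new, randomly initialized 10-class head
    if freeze_backbone:
        # Layers before the pooling layer keep their pretrained weights.
        for p in model.features.parameters():
            p.requires_grad = False
    return model


if __name__ == "__main__":
    pretrained = Classification1_Model(num_classes=5)
    # In practice: state_dict = torch.load(path_to_state_dict)
    model = build_10_class_model(pretrained.state_dict())
    out = model(torch.randn(2, 1, 8, 8, 8))  # (batch, channels, depth, height, width)
    print(out.shape)  # torch.Size([2, 10])
    # Give the optimizer only the parameters that are still trainable.
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=1e-3
    )
```

Once the new head has converged, you can set `requires_grad = True` on the backbone again and continue with a smaller learning rate if you want to fine-tune the whole network.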

A segmentation model has an encoder-decoder architecture. You can take the weights of the encoder part, add a linear layer on top for classification, and use the result as a classifier.
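One way to do this without touching state_dict keys at all is to build the segmentation model, load its checkpoint, and reuse its encoder submodule directly inside a classifier. The sketch below assumes a hypothetical toy 3D model `SegModel3D` that exposes its encoder as `self.encoder`; real 3D U-Net implementations usually have skip connections and may split the encoder into several blocks, so the attribute names will differ. Adaptive average pooling makes the head independent of the volume size.

```python
import torch
import torch.nn as nn


class SegModel3D(nn.Module):
    """Hypothetical toy encoder-decoder for 3D segmentation (no skip connections)."""

    def __init__(self, num_seg_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(2),
            nn.Conv3d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(16, 8, kernel_size=2, stride=2),
            nn.ReLU(),
            nn.Conv3d(8, num_seg_classes, kernel_size=1),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))


class EncoderClassifier(nn.Module):
    """Reuses a (pretrained) encoder and adds pooling plus a linear classification head."""

    def __init__(self, encoder, encoder_out_channels, num_classes):
        super().__init__()
        self.encoder = encoder                # same module object, so same weights
        self.pool = nn.AdaptiveAvgPool3d(1)   # any spatial size -> 1x1x1
        self.fc = nn.Linear(encoder_out_channels, num_classes)

    def forward(self, x):
        x = self.pool(self.encoder(x))
        return self.fc(torch.flatten(x, 1))


seg = SegModel3D()
# In practice: seg.load_state_dict(torch.load("segmentation.pt"))  # hypothetical file name
clf = EncoderClassifier(seg.encoder, encoder_out_channels=16, num_classes=10)
print(clf(torch.randn(2, 1, 16, 16, 16)).shape)  # torch.Size([2, 10])
```

Because `clf.encoder` is the very same module as `seg.encoder`, training the classifier also changes the segmentation model's encoder; pass `copy.deepcopy(seg.encoder)` instead if you need to keep the segmentation model untouched.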


I am familiar with the theory, but I don't know how to apply it.
Can you please explain it with code?
My specific questions are:
1) When we save a model to a .pth file, all of its weights are saved, but I only need the encoder weights. How can I get them?
2) How can I build a classification model that contains only the encoder part of a segmentation model (for example, UNet3D) and add a classification layer on top of it?

@maryma Sure, here is some minimal code that should help you understand how to use the saved model.

We first define an encoder-decoder (autoencoder) model that does both encoding and decoding:

```python
import torch
import torch.nn as nn


class encoderDecoder(nn.Module):
    def __init__(self):
        super(encoderDecoder, self).__init__()
        # Toy layers; replace them with your real encoder and decoder layers.
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3))
        self.decoder = nn.Sequential(nn.ConvTranspose2d(16, 3, 3))

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
```

We can save its weights using:

```python
model = encoderDecoder()
torch.save(model.state_dict(), 'autoencoder.pt')
```

Next, we define a model with only the encoder layers:

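```python
class encoderModel(nn.Module):
    def __init__(self):
        super(encoderModel, self).__init__()
        # Must match the encoder in encoderDecoder (same attribute name, layers and shapes).
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3))

    def forward(self, x):
        x = self.encoder(x)
        return x
```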

Now we can load the weights using:

```python
model2 = encoderModel()
model2.load_state_dict(torch.load('autoencoder.pt'), strict=False)
```

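Because of strict=False this does not raise an error; instead, load_state_dict returns a record of the keys that did not match, which an interactive session prints as:

```
_IncompatibleKeys(missing_keys=[], unexpected_keys=['decoder.0.weight', 'decoder.0.bias'])
```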

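This is fine, because we do not need the decoder layers here, and it shows how to reuse only the encoder part of a saved model. Note that strict=False only skips missing and unexpected keys: if a key exists in both models but the tensor shapes differ, load_state_dict still raises an error, so the encoder must be defined exactly as in the original model.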
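For question 1, an alternative to strict=False is to keep only the encoder entries of the saved state_dict. A state_dict is an ordered dictionary that maps parameter names such as `encoder.0.weight` to tensors, so filtering on the `encoder.` prefix extracts the encoder weights, which can then be saved on their own and loaded with the default strict=True. The sketch below redefines the toy models from this thread so it runs by itself; the file name `encoder_only.pt` is hypothetical.

```python
import torch
import torch.nn as nn


class encoderDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3))
        self.decoder = nn.Sequential(nn.ConvTranspose2d(16, 3, 3))

    def forward(self, x):
        return self.decoder(self.encoder(x))


class encoderModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3))

    def forward(self, x):
        return self.encoder(x)


full_state = encoderDecoder().state_dict()  # in practice: torch.load('autoencoder.pt')

# Keep only the entries whose names start with "encoder."
encoder_state = {k: v for k, v in full_state.items() if k.startswith("encoder.")}
print(sorted(encoder_state))  # ['encoder.0.bias', 'encoder.0.weight']
torch.save(encoder_state, "encoder_only.pt")  # hypothetical file name

model2 = encoderModel()
result = model2.load_state_dict(torch.load("encoder_only.pt"))  # strict=True by default
print(result)  # <All keys matched successfully>

# Confirm the tensors were actually copied into the new model.
assert torch.equal(model2.encoder[0].weight, full_state["encoder.0.weight"])
```

The prefix filter works because both models name the submodule `encoder`; if the names differ, you would also have to rename the keys (for example with `k.replace(...)`) before loading.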

For the second question, we reuse the same encoderDecoder model and the autoencoder.pt file saved above.

For classification, we add a linear layer (or a few layers, if you want) on top of the encoder:

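```python
class classificationModel(nn.Module):
    def __init__(self, in_features=16 * 30 * 30, num_classes=10):
        super(classificationModel, self).__init__()
        # Must match the encoder in encoderDecoder so the saved weights fit.
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3))
        # in_features = flattened size of the encoder output (16 * 30 * 30 assumes 32x32 inputs).
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(in_features, num_classes))

    def forward(self, x):
        x = self.encoder(x)
        x = self.fc(x)
        return x
```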
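If you are unsure what to use for the input number of features, one common trick is to pass a dummy tensor of your real input size through the encoder once and count the flattened features. This sketch assumes the toy encoder `Conv2d(3, 16, 3)` from this thread and hypothetical 32x32 RGB inputs; for 3D MRI volumes the dummy tensor would have shape (1, channels, depth, height, width) instead.

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Conv2d(3, 16, 3))

with torch.no_grad():
    dummy = torch.zeros(1, 3, 32, 32)  # (batch, channels, height, width)
    in_features = encoder(dummy).flatten(1).shape[1]

print(in_features)  # 14400 = 16 * 30 * 30
fc = nn.Sequential(nn.Flatten(), nn.Linear(in_features, 10))  # 10 classes
```

With Flatten followed by Linear, in_features is tied to the input resolution; inserting an adaptive pooling layer (e.g. `nn.AdaptiveAvgPool2d(1)`) before Flatten would make it equal to the number of encoder channels regardless of image size.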

Then load the weights just as before:

```python
model2 = classificationModel()
model2.load_state_dict(torch.load('autoencoder.pt'), strict=False)
```

This time the returned result also lists the fc layer's parameters (fc.1.weight and fc.1.bias) under missing_keys. That is fine, because the original model never had those layers: they keep their random initialization and are learned when you train the classifier. This way you can use the pretrained encoder to train a classifier.
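To actually train the classifier on the new, smaller dataset, a common starting point is to freeze the pretrained encoder and optimize only the new fc layers, unfreezing later if needed. This is a sketch with random tensors standing in for real data; the model mirrors the classificationModel above with assumed 32x32 inputs and 10 classes, and the learning rate is an arbitrary placeholder.

```python
import torch
import torch.nn as nn


class classificationModel(nn.Module):
    def __init__(self, in_features=16 * 30 * 30, num_classes=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(3, 16, 3))
        self.fc = nn.Sequential(nn.Flatten(), nn.Linear(in_features, num_classes))

    def forward(self, x):
        return self.fc(self.encoder(x))


model2 = classificationModel()
# In practice: model2.load_state_dict(torch.load('autoencoder.pt'), strict=False)

# Freeze the pretrained encoder so only the new head is updated.
for p in model2.encoder.parameters():
    p.requires_grad = False

optimizer = torch.optim.Adam(model2.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

images = torch.randn(4, 3, 32, 32)   # stand-in batch of images
labels = torch.randint(0, 10, (4,))  # stand-in class labels in [0, 10)

encoder_before = model2.encoder[0].weight.clone()
model2.train()
optimizer.zero_grad()
loss = criterion(model2(images), labels)
loss.backward()
optimizer.step()
print(loss.item())
```

After the step, the encoder weights are unchanged and only the fc parameters have moved; in a real run you would repeat this over a DataLoader for several epochs.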


Thanks for your complete answer!