Using a custom model as feature extractor

I have a custom model whose state_dict I saved as a .pth file after training. Now I want to use it as a feature extractor. How can I do this?

network architecture

Please see the model architecture in the link above.

Specifically, I want to extract the features of the colored layer. Can someone point me to an example or a tutorial for this? Thanks!

Based on your image, it looks like your model has two different feature extractors, whose outputs are concatenated and passed into another module.
Depending on your model definition, you could just change the forward pass and return the features from the first feature extractor.
I’ve created a simple example. The commented-out lines would call the complete model.

import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # first feature extractor
        self.feature1 = nn.Sequential(
            nn.Linear(100, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        )
        # second feature extractor
        self.feature2 = nn.Sequential(
            nn.Linear(100, 20),
            nn.ReLU(),
            nn.Linear(20, 5)
        )
        # head that consumes the concatenated features
        self.main_block = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        x1 = self.feature1(x)
        x2 = self.feature2(x)

        # uncomment these lines to run the complete model
        #x = torch.cat((x1, x2), dim=1)
        #x = self.main_block(x)
        return x1  #x


model = MyModel()
x = torch.randn(1, 100)
output = model(x)

If it’s not possible to change the forward pass for some reason, you could try to register a forward hook on the specific layer. Here is a small example.
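A minimal sketch of the hook approach, reusing the MyModel definition from above; the hooked submodule (feature1) is just a placeholder for whichever layer you are interested in:

import torch

activations = {}

def save_activation(name):
    def hook(module, input, output):
        # store the output of the hooked module
        activations[name] = output.detach()
    return hook

model = MyModel()
# register the hook on the submodule whose features you want
model.feature1.register_forward_hook(save_activation('feature1'))

x = torch.randn(1, 100)
out = model(x)                       # runs the full forward pass
features = activations['feature1']  # extracted features

The advantage of this approach is that the original forward pass stays untouched; the hook simply caches the intermediate output whenever the model is called.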

Thanks ptrblck. Yes, it has two different feature extractors, whose outputs are concatenated and passed into another module. I trained such a system and want to load the weights from the model I saved. I can change the forward pass, but as I said I don’t want to change the weights. How can I achieve this?

Thanks.

You could just load the state_dict into your model, probably set it to evaluation mode with model.eval(), and change the forward pass to return your desired activation.

To save and load a model, have a look at the Serialization semantics.
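A minimal sketch of that workflow, assuming the state_dict was saved to a file named model.pth (the filename is a placeholder) and matches the MyModel definition above:

import torch

model = MyModel()
# load the trained weights; the forward pass can still return x1 only
model.load_state_dict(torch.load('model.pth'))
model.eval()  # switch dropout/batchnorm layers to evaluation behavior

with torch.no_grad():
    x = torch.randn(1, 100)
    features = model(x)  # features from feature1 with the modified forward

Changing the forward pass only affects which tensors are returned, not the stored parameters, so the loaded weights stay exactly as they were saved.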

Thanks ptrblck. I was trying the same thing with model.eval(). I think I missed something earlier. I was previously calling load_state_dict in a loop and then trying model.eval(). I’ll see if I missed some trick. Thanks again.