Combining Trained Models in PyTorch

Hi all,
I’m currently working on two models that train on separate (but related) types of data. I’d like to make a combined model that takes in an instance of each type of data, runs each through the corresponding individually pre-trained model, and then has a few feed-forward layers on top that process the combined result of the two individual models.

So far, I know that I can modify forward() to take in both inputs, so I have just copied the architectures of my individual models into the combined one, process both inputs separately by running them through the correct layers in forward(), and then combine the results as described above. What I can’t figure out is how to use the pre-trained models in the combined model, rather than having to train the same architecture from scratch.
I’d really appreciate any input! Thanks!

18 Likes

The logic you described seems to be reasonable.
I’m not sure I understand the issue correctly, but it seems you are not sure how to restore the pretrained models and use them in the new model.
If that’s the case, you could instantiate both pretrained models, load the state_dicts, and pass them to your new model.
I’ve created a small example, since it might be easier to just see the code:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(20, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x


class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        self.classifier = nn.Linear(4, 2)
        
    def forward(self, x1, x2):
        x1 = self.modelA(x1)
        x2 = self.modelB(x2)
        x = torch.cat((x1, x2), dim=1)
        x = self.classifier(F.relu(x))
        return x

# Create models and load state_dicts    
modelA = MyModelA()
modelB = MyModelB()
# Load state dicts
modelA.load_state_dict(torch.load(PATH))
modelB.load_state_dict(torch.load(PATH))

model = MyEnsemble(modelA, modelB)
x1, x2 = torch.randn(1, 10), torch.randn(1, 20)
output = model(x1, x2)
57 Likes

Thanks so much! Exactly what I needed; I wasn’t aware I could use whole models in the forward (but I guess that makes sense, as they all subclass nn.Module).

1 Like

@ptrblck- Can you show an example of how would you do the same thing for let’s say densenet161 and resnet152? Really appreciate your help.

The example above is as close as it gets: simply replace ModelA / ModelB with a ResNet / DenseNet.

If you are doing a homework problem, I don’t think we’re here to give you the exact solution and give you free points :stuck_out_tongue:

12 Likes

@smth - Not a homework problem, just trying to learn something completely out of my domain.

If I understand it correctly, if modelA is resnet152, I can change self.fc1 to self.fc1 = nn.Linear(2048, x), where 2048 is in_features for resnet152 and x is the out_features one desires. We can do the same for modelB. But in self.classifier = nn.Linear(4, 2), where did those numbers 4 and 2 come from?

I understand these are very trivial for a man of your stature in programming, but for newbies… I appreciate your prompt response and making sure you are not solving someone else’s homework.

2 Likes

The number of input features for self.classifier is defined by the sum of the output features of both “pre-ensemble” models.
As you can see, ModelA and ModelB both output 2 values (from self.fc1). These features will be concatenated in MyEnsemble and passed to self.classifier, which takes 4 input features.
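
For the resnet152 / densenet161 case mentioned above, a minimal sketch could look like this (the head attributes fc and classifier and their in_features come from torchvision; the 2 output features per model and the final 2 classes are just placeholder choices matching the toy example, and MyEnsemble is the class defined earlier in this thread):

import torch
import torch.nn as nn
import torchvision.models as models

# Replace each classification head so that both backbones output 2 features
modelA = models.resnet152(pretrained=True)
modelA.fc = nn.Linear(modelA.fc.in_features, 2)        # in_features is 2048 for resnet152

modelB = models.densenet161(pretrained=True)
modelB.classifier = nn.Linear(modelB.classifier.in_features, 2)  # in_features is 2208 for densenet161

# The ensemble classifier then takes 2 + 2 = 4 concatenated features
model = MyEnsemble(modelA, modelB)

x1 = torch.randn(1, 3, 224, 224)
x2 = torch.randn(1, 3, 224, 224)
output = model(x1, x2)  # shape: [1, 2]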

6 Likes

Thanks @ptrblck for clarification. Got it finally.

@ptrblck If I call eval() on a MyEnsemble object, will it automatically call eval on MyModelA and MyModelB?

Thank you for your answer! :100:

1 Like

Yes.
eval() recursively calls eval() on all child modules.
Ref- https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.eval
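
A quick way to check this with the toy classes from above:

model = MyEnsemble(MyModelA(), MyModelB())
model.eval()
print(model.modelA.training)  # False
print(model.modelB.training)  # False

model.train()
print(model.modelA.training)  # True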

5 Likes

@ptrblck, for your example above: are the weights of ModelA and ModelB frozen, or trained together with the classifier?

All parameters will be trained end-to-end in the example.
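
You can verify this by checking that every parameter of the ensemble requires gradients and shows up in model.parameters(), which is what you would typically pass to the optimizer:

model = MyEnsemble(modelA, modelB)
print(all(p.requires_grad for p in model.parameters()))  # True

# Passing model.parameters() to the optimizer therefore updates
# modelA, modelB, and the classifier together.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)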

@ptrblck, thank you for the toy example. How would you change the script if modelB’s input were taken from the output of one of modelA’s layers, and the ensemble output had 2 heads? In the modification below, fc1 of modelA is run twice (not efficient):

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x
    

class MyModelB(nn.Module):
    def __init__(self):
        super(MyModelB, self).__init__()
        self.fc1 = nn.Linear(10, 2)
        
    def forward(self, x):
        x = self.fc1(x)
        return x

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x1, x2):
        head1 = self.modelA(x1)
        head2 = self.modelB(self.modelA.fc1(x1))
        x = torch.cat((head1, head2), dim=1)
        return x

Thanks!

1 Like

You could return both values in the forward method of ModelA:

class MyModelA(nn.Module):
    def __init__(self):
        super(MyModelA, self).__init__()
        self.fc1 = nn.Linear(20, 10)
        self.fc2 = nn.Linear(10, 2)
        
    def forward(self, x):
        x1 = self.fc1(x)
        x2 = self.fc2(x1)
        return x1, x2

...

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB
        
    def forward(self, x1):
        head1a, head1b = self.modelA(x1)
        head2 = self.modelB(head1a)
        x = torch.cat((head1b, head2), dim=1)
        return x
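
With this version, fc1 of modelA is executed only once; a small usage sketch (the shapes match the Linear layers above):

modelA = MyModelA()   # fc1: 20 -> 10, fc2: 10 -> 2
modelB = MyModelB()   # fc1: 10 -> 2
model = MyEnsemble(modelA, modelB)

x1 = torch.randn(1, 20)
output = model(x1)    # shape: [1, 4], i.e. 2 features from each head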
5 Likes

I see! That requires modifying the forward method of modelA?
Here are a few words about what I am trying to do.
I have ModelA :
[input/image -> encoder1 -> header1 -> output1].
The model was trained and has pretty good accuracy.
Now I want to add a 2nd task.
For that purpose, I run the entire dataset through ModelA and, for each example, save the activation at the last layer of the encoder as a numpy array file.
Those numpy arrays are then used as the inputs to modelB [input2 -> header2 -> output2]. Then I run the optimization and save the weights of modelB.

Finally, I want to combine modelA and modelB into a combo-model with 2 heads:

combo_model:  [input/image -> encoder1 -> header1 -> output1 
                                      \-> header2 -> output2

This combo_model is only used at test time.
(I understand I could start with the combo-model architecture, train the modelA part while freezing the modelB branch, and then optimize modelB while freezing modelA. IMHO, that approach is more cumbersome!)

You don’t need to retrain the models, but could just redefine the forward of modelA and load its state_dict.
The state_dict contains only the parameters of the model, so you are free to manipulate the actual forward method as you wish, since the computation graph is created dynamically. :slight_smile:
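
For your setup that could look roughly like this; encoder1, header1 and header2 are just placeholders for whatever submodules your real models actually expose:

class ComboModel(nn.Module):
    def __init__(self, modelA, modelB):
        super(ComboModel, self).__init__()
        # Reuse the pretrained submodules; their weights come from the loaded state_dicts
        self.encoder1 = modelA.encoder1
        self.header1 = modelA.header1
        self.header2 = modelB.header2

    def forward(self, x):
        features = self.encoder1(x)       # the encoder runs only once
        output1 = self.header1(features)
        output2 = self.header2(features)
        return output1, output2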

4 Likes

OK, thank you for your answers.

How would one continue creating these ensembles, with 2, then 4, then 8 neural networks and so on?
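
One way to scale the pattern from above is to keep the sub-models in an nn.ModuleList and concatenate their outputs before a shared classifier; a sketch, assuming each sub-model gets its own input tensor and outputs the same number of features:

class MyEnsembleN(nn.Module):
    def __init__(self, models, features_per_model=2, num_classes=2):
        super(MyEnsembleN, self).__init__()
        # nn.ModuleList registers all sub-models, so their parameters are tracked
        self.models = nn.ModuleList(models)
        self.classifier = nn.Linear(features_per_model * len(models), num_classes)

    def forward(self, inputs):
        # `inputs` is a list/tuple with one tensor per sub-model
        outs = [m(x) for m, x in zip(self.models, inputs)]
        x = torch.cat(outs, dim=1)
        return self.classifier(F.relu(x))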

That’s helped a lot. Thanks!

1 Like

Hi, I’m working with the exact same structure as your example above. How can I keep the weights of the two pre-trained models A and B frozen and train only the classifier layer of the ensemble model?

thanks
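
A common pattern for that (standard PyTorch, using the class names from the example above) is to disable gradients for the two pretrained sub-models and pass only the classifier parameters to the optimizer:

model = MyEnsemble(modelA, modelB)

# Freeze the pretrained sub-models
for param in model.modelA.parameters():
    param.requires_grad = False
for param in model.modelB.parameters():
    param.requires_grad = False

# Optimize only the classifier on top
optimizer = torch.optim.SGD(model.classifier.parameters(), lr=1e-3)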