Merge two networks, one inside the other

Hi, I need to merge two neural networks. Stacking one on top of the other is not a problem, but I have no idea how to insert one network between the layers of the other. I'll give a quick math explanation first, and then the code:

I have two networks that have the following structure (activation functions are removed for simplicity):

f(x) = x W_1 W_2 W_3 and g(x) = x M_1 M_2, where W_i and M_i are weight matrices. I need to train both networks in isolation, and then merge them to obtain the following network:

F(x) = x W_1 W_2 M_1 M_2 W_3.

Now, speaking about the code, I have the following two networks:

import torch
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super(Net_1, self).__init__()

        self.W1 = nn.Linear(28*28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)
    
    def forward(self, x):

        out = self.W1(x)
        out = self.relu1(out)
        out = self.W2(out)
        out = self.relu2(out)
        out = self.W3(out)

        return out

class Net_2(nn.Module):
    def __init__(self,hidden_dim):
        super(Net_2, self).__init__()

        
        self.M1 = nn.Linear(hidden_dim, hidden_dim*2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim*2, hidden_dim)
    
    def forward(self, x):

        out = self.M1(x)
        out = self.relu(out)
        out = self.M2(out)

        return out

Now, I train the two networks in isolation, and then I need to generate the following net:

class Final_Net(nn.Module):
    def __init__(self,hidden_dim):
        super(Final_Net, self).__init__()

        
        self.W1 = nn.Linear(28*28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.M1 = nn.Linear(hidden_dim, hidden_dim*2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim*2, hidden_dim)
        self.relu3 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)
    
    def forward(self, x):

        out = self.W1(x)
        out = self.relu1(out)
        out = self.W2(out)
        out = self.relu2(out)
        out = self.M1(out)
        out = self.relu(out)
        out = self.M2(out)
        out = self.relu3(out)
        out = self.W3(out)

        return out

But I have no idea how to do it. An answer that does the above using nn.Sequential would also be fine.

For your example, you can merge the two state dicts and load the result into Final_Net:

net_1 = Net_1(hidden_dim)
net_2 = Net_2(hidden_dim)
# HERE TRAIN net_1
...
# HERE TRAIN net_2
...

# after training net_1 and net_2
net_1_params = net_1.state_dict()
net_2_params = net_2.state_dict()
assert not (net_1_params.keys() & net_2_params.keys())  # IMPORTANT: the two networks must not share any parameter names
net_1_params.update(net_2_params)
final_net = Final_Net(hidden_dim)
final_net.load_state_dict(net_1_params)
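
As a sanity check on the snippet above, here is a minimal, self-contained sketch that merges the two state dicts and confirms the merged network computes the same thing as chaining the original layers by hand. The value hidden_dim = 16, the random input batch, and the use of freshly initialised networks in place of trained ones are all assumptions made for illustration; the classes are compact versions of the ones in the question.

```python
import torch
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        out = torch.relu(self.W1(x))
        out = torch.relu(self.W2(out))
        return self.W3(out)

class Net_2(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        return self.M2(torch.relu(self.M1(x)))

class Final_Net(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # Attribute names must match those used in Net_1 and Net_2,
        # because state_dict keys are derived from them.
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        out = torch.relu(self.W1(x))
        out = torch.relu(self.W2(out))
        out = torch.relu(self.M1(out))
        out = torch.relu(self.M2(out))
        return self.W3(out)

hidden_dim = 16  # assumed value, for illustration only
torch.manual_seed(0)
net_1, net_2 = Net_1(hidden_dim), Net_2(hidden_dim)  # stand-ins for trained nets

# Merge the two state dicts; the key sets must be disjoint.
merged = {**net_1.state_dict(), **net_2.state_dict()}
final_net = Final_Net(hidden_dim)
# strict=True is the default: a missing or unexpected key raises an error.
result = final_net.load_state_dict(merged)

# Compare against chaining the original layers by hand.
x = torch.randn(4, 28 * 28)
with torch.no_grad():
    h = torch.relu(net_1.W2(torch.relu(net_1.W1(x))))
    expected = net_1.W3(torch.relu(net_2(h)))
    same = torch.allclose(final_net(x), expected)
```

Because load_state_dict copies the values, final_net ends up with its own parameters, so continuing to train it should not modify net_1 or net_2.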

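NOTE: this does not always work. It relies on the parameter names in net_1 and net_2 being disjoint and matching the attribute names used in Final_Net (W1, W2, W3, M1, M2).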
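
Since the question also asks about nn.Sequential, one possible alternative is to build the merged network from the already trained modules themselves, which avoids any dependence on parameter names. This is a sketch rather than a definitive recipe: hidden_dim = 16 and the random input are assumptions, and freshly initialised networks stand in for trained ones.

```python
import torch
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        out = torch.relu(self.W1(x))
        out = torch.relu(self.W2(out))
        return self.W3(out)

class Net_2(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        return self.M2(torch.relu(self.M1(x)))

hidden_dim = 16  # assumed value, for illustration only
torch.manual_seed(0)
net_1, net_2 = Net_1(hidden_dim), Net_2(hidden_dim)  # stand-ins for trained nets

# Reuse the trained layers directly: W1, W2, then all of Net_2, then W3.
final_net = nn.Sequential(
    net_1.W1, nn.ReLU(),
    net_1.W2, nn.ReLU(),
    net_2,      # the whole trained Net_2 as a single block
    nn.ReLU(),
    net_1.W3,
)

x = torch.randn(4, 28 * 28)
with torch.no_grad():
    out = final_net(x)
```

Note that final_net here shares its parameters with net_1 and net_2 instead of copying them, so further training of final_net also changes the original networks. If independent copies are needed, wrapping the result in copy.deepcopy should give them.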