Hi, I need to merge two neural networks. Stacking one on top of the other is not a problem, but I have no idea how to do it when the second network has to be inserted between the layers of the first. I’ll first give a quick mathematical explanation and then the code:
I have two networks with the following structures (activation functions are omitted for simplicity):
f(x) = x W_1 W_2 W_3 and g(x) = x M_1 M_2, where W_i and M_i are weight matrices. I need to train both networks in isolation, and then merge them to obtain the following network:
F(x) = x W_1 W_2 M_1 M_2 W_3.
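To connect the formula with PyTorch, here is a small check. It is a sketch that assumes bias-free `nn.Linear` layers (so each layer is a pure matrix product, as in the formula) and an arbitrary `hidden_dim = 16`. One detail worth noting: `nn.Linear` stores its weight with shape `(out_features, in_features)` and computes `x @ weight.T`, so the matrix W_i in the formula corresponds to `layer.weight.T`. The insertion is only possible because g maps `hidden_dim` back to `hidden_dim`, which is exactly the input size W_3 expects.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim = 16  # arbitrary size, chosen only for this check

# Bias-free layers, so each one is exactly "multiply by a matrix".
# Activations are omitted, matching the simplified formula.
W1 = nn.Linear(28 * 28, hidden_dim, bias=False)
W2 = nn.Linear(hidden_dim, hidden_dim, bias=False)
W3 = nn.Linear(hidden_dim, 10, bias=False)
M1 = nn.Linear(hidden_dim, hidden_dim * 2, bias=False)
M2 = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)

x = torch.randn(4, 28 * 28)  # a batch of 4 flattened 28x28 inputs

# F(x) = x W_1 W_2 M_1 M_2 W_3, where W_i is layer.weight.T in PyTorch.
F_matrix = x @ W1.weight.T @ W2.weight.T @ M1.weight.T @ M2.weight.T @ W3.weight.T

# The same product written as a chain of module calls.
F_modules = W3(M2(M1(W2(W1(x)))))

print(F_modules.shape)  # torch.Size([4, 10])
print(torch.allclose(F_matrix, F_modules, atol=1e-5))
```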
In PyTorch, the two networks look like this:
```python
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super(Net_1, self).__init__()
        self.W1 = nn.Linear(28*28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        out = self.W1(x)
        out = self.relu1(out)
        out = self.W2(out)
        out = self.relu2(out)
        out = self.W3(out)
        return out

class Net_2(nn.Module):
    def __init__(self, hidden_dim):
        super(Net_2, self).__init__()
        # Maps hidden_dim -> hidden_dim, so it can sit inside Net_1
        self.M1 = nn.Linear(hidden_dim, hidden_dim*2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim*2, hidden_dim)

    def forward(self, x):
        out = self.M1(x)
        out = self.relu(out)
        out = self.M2(out)
        return out
```
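A quick sketch (with an arbitrary `hidden_dim = 32`) that inspects the two modules above. The parameter names in each `state_dict` come from the attribute names used in `__init__`, and `Net_2` maps a `(batch, hidden_dim)` tensor to the same shape, which is what allows it to sit between `W2` and `W3`. The `forward` methods are condensed into single expressions but compute the same thing as the originals.

```python
import torch
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        return self.W3(self.relu2(self.W2(self.relu1(self.W1(x)))))

class Net_2(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        return self.M2(self.relu(self.M1(x)))

hidden_dim = 32  # arbitrary size for this check
net1, net2 = Net_1(hidden_dim), Net_2(hidden_dim)

# ReLU has no parameters, so only the Linear layers show up.
print(list(net1.state_dict().keys()))
# ['W1.weight', 'W1.bias', 'W2.weight', 'W2.bias', 'W3.weight', 'W3.bias']
print(list(net2.state_dict().keys()))
# ['M1.weight', 'M1.bias', 'M2.weight', 'M2.bias']

# Net_2 preserves the hidden size, so its output can feed W3.
h = torch.randn(8, hidden_dim)
print(net2(h).shape)  # torch.Size([8, 32])
```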
After training the two networks separately, I need to build the following network from their trained weights:
```python
class Final_Net(nn.Module):
    def __init__(self, hidden_dim):
        super(Final_Net, self).__init__()
        # Layers from Net_1 (before the insertion point)
        self.W1 = nn.Linear(28*28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        # Layers from Net_2 (inserted between W2 and W3)
        self.M1 = nn.Linear(hidden_dim, hidden_dim*2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim*2, hidden_dim)
        self.relu3 = nn.ReLU()
        # Last layer of Net_1
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        out = self.W1(x)
        out = self.relu1(out)
        out = self.W2(out)
        out = self.relu2(out)
        out = self.M1(out)
        out = self.relu(out)
        out = self.M2(out)
        out = self.relu3(out)
        out = self.W3(out)
        return out
```
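One possible way to fill `Final_Net` with trained weights, as a sketch: because its attributes reuse the names `W1`, `W2`, `W3` from `Net_1` and `M1`, `M2` from `Net_2`, the union of the two `state_dict`s has exactly the keys `Final_Net` expects, so a strict `load_state_dict` succeeds. The randomly initialized `net1`/`net2` below only stand in for the separately trained networks, and `hidden_dim = 64` is arbitrary.

```python
import torch
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        return self.W3(self.relu2(self.W2(self.relu1(self.W1(x)))))

class Net_2(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        return self.M2(self.relu(self.M1(x)))

class Final_Net(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)
        self.relu3 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        out = self.relu2(self.W2(self.relu1(self.W1(x))))  # first part of Net_1
        out = self.M2(self.relu(self.M1(out)))              # Net_2 inserted here
        return self.W3(self.relu3(out))                     # last layer of Net_1

hidden_dim = 64  # arbitrary
net1, net2 = Net_1(hidden_dim), Net_2(hidden_dim)  # stand-ins for the trained nets
# ... train net1 and net2 separately here ...

final = Final_Net(hidden_dim)
# The key sets (W* and M*) are disjoint, so merging the dicts loses nothing.
# strict=True (the default) raises if any key is missing or unexpected.
final.load_state_dict({**net1.state_dict(), **net2.state_dict()})

# Sanity check: the merged net matches the hand-written composition.
x = torch.randn(5, 28 * 28)
with torch.no_grad():
    h = net1.relu2(net1.W2(net1.relu1(net1.W1(x))))
    expected = net1.W3(torch.relu(net2(h)))
    print(torch.allclose(final(x), expected))  # True
```

`load_state_dict` copies the tensor values, so `final` can then be fine-tuned without modifying `net1` or `net2`.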
But I have no idea how to fill this network with the trained weights of Net_1 and Net_2. An answer that does the same thing with nn.Sequential would also be fine.
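For the `nn.Sequential` variant, here is a sketch that reuses the trained layer objects directly, with no new class needed. It assumes the same `Net_1`/`Net_2` definitions (random stand-ins for the trained networks, arbitrary `hidden_dim = 64`). Passing an `OrderedDict` names the entries, so the parameters keep readable names such as `W1.weight` and `M1.bias`. The `copy.deepcopy` calls are a design choice: they give the merged network its own weights; without them, `merged` would share parameters with `net1` and `net2`, so training one would also change the other.

```python
import copy
from collections import OrderedDict

import torch
import torch.nn as nn

class Net_1(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.W1 = nn.Linear(28 * 28, hidden_dim)
        self.relu1 = nn.ReLU()
        self.W2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.W3 = nn.Linear(hidden_dim, 10)

    def forward(self, x):
        return self.W3(self.relu2(self.W2(self.relu1(self.W1(x)))))

class Net_2(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.M1 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.relu = nn.ReLU()
        self.M2 = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x):
        return self.M2(self.relu(self.M1(x)))

hidden_dim = 64  # arbitrary
net1, net2 = Net_1(hidden_dim), Net_2(hidden_dim)  # stand-ins for the trained nets

# x W_1 W_2 M_1 M_2 W_3, with a ReLU after every layer except the last.
merged = nn.Sequential(OrderedDict([
    ("W1", copy.deepcopy(net1.W1)),
    ("relu1", nn.ReLU()),
    ("W2", copy.deepcopy(net1.W2)),
    ("relu2", nn.ReLU()),
    ("M1", copy.deepcopy(net2.M1)),
    ("relu", nn.ReLU()),
    ("M2", copy.deepcopy(net2.M2)),
    ("relu3", nn.ReLU()),
    ("W3", copy.deepcopy(net1.W3)),
]))

x = torch.randn(5, 28 * 28)
with torch.no_grad():
    h = net1.relu2(net1.W2(net1.relu1(net1.W1(x))))
    expected = net1.W3(torch.relu(net2(h)))
    print(torch.allclose(merged(x), expected))  # True
```

Because the entry names match the attribute names of `Final_Net`, `merged.state_dict()` has the same keys as `Final_Net(hidden_dim).state_dict()`.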