The logic you described seems to be reasonable.
I’m not sure I understand the issue correctly, but it seems you are unsure how to restore the pretrained models and use them in the new model.
If that’s the case, you could instantiate both pretrained models, load their state_dicts, and pass the models to your new model.
I’ve created a small example, since it might be easier to just see the code:
class MyModelA(nn.Module):
    """Example sub-model A: one linear layer mapping 10 input features to 2."""

    def __init__(self):
        super(MyModelA, self).__init__()
        # Single fully connected layer: 10 input features -> 2 output features.
        self.fc1 = nn.Linear(10, 2)

    def forward(self, x):
        """Apply ``fc1`` to ``x`` and return the result.

        ``x`` must have 10 features in its last dimension, as required by
        ``nn.Linear(10, 2)``.
        """
        x = self.fc1(x)
        return x
class MyModelB(nn.Module):
    """Example sub-model B: one linear layer mapping 20 input features to 2."""

    def __init__(self):
        super(MyModelB, self).__init__()
        # Single fully connected layer: 20 input features -> 2 output features.
        self.fc1 = nn.Linear(20, 2)

    def forward(self, x):
        """Apply ``fc1`` to ``x`` and return the result.

        ``x`` must have 20 features in its last dimension, as required by
        ``nn.Linear(20, 2)``.
        """
        x = self.fc1(x)
        return x
class MyEnsemble(nn.Module):
    """Combine two sub-models by concatenating their outputs and classifying.

    Args:
        modelA: module applied to the first input of ``forward``.
        modelB: module applied to the second input of ``forward``.

    The classifier is ``nn.Linear(4, 2)``, so the outputs of ``modelA`` and
    ``modelB`` must add up to 4 features along dim 1 after concatenation
    (for example, 2 features from each sub-model).
    """

    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        # Assigning the sub-models as attributes registers them as child
        # modules of the ensemble.
        self.modelA = modelA
        self.modelB = modelB
        # Final layer on top of the concatenated sub-model outputs.
        self.classifier = nn.Linear(4, 2)

    def forward(self, x1, x2):
        """Run both sub-models and classify their concatenated outputs.

        Args:
            x1: input passed to ``modelA``.
            x2: input passed to ``modelB``.

        Returns:
            Output of ``classifier`` applied to the ReLU of the two
            sub-model outputs concatenated along dim 1.
        """
        out_a = self.modelA(x1)
        out_b = self.modelB(x2)
        # Concatenate along the feature dimension (dim 1); both outputs must
        # therefore agree in every other dimension.
        x = torch.cat((out_a, out_b), dim=1)
        x = self.classifier(F.relu(x))
        return x
# Create the two sub-models.
modelA = MyModelA()
modelB = MyModelB()

# Load the pretrained weights. Each model needs its OWN checkpoint: the two
# architectures differ (fc1 is Linear(10, 2) vs. Linear(20, 2)), so a single
# state_dict cannot be loaded into both without a size-mismatch error.
# The paths below are placeholders; point them at your saved state_dicts.
PATH_A = "modelA.pth"
PATH_B = "modelB.pth"
modelA.load_state_dict(torch.load(PATH_A))
modelB.load_state_dict(torch.load(PATH_B))

# Wrap the restored sub-models in the ensemble.
model = MyEnsemble(modelA, modelB)

# Dummy inputs: batch of 1 with 10 features for modelA and 20 for modelB.
x1, x2 = torch.randn(1, 10), torch.randn(1, 20)
output = model(x1, x2)