Ensemble Model Training

I built up an ensemble model:

class ModelA(nn.Modue):
      def __init__(self):
           super(ModelA, self).__init__()
           self.encoder = SentenceTransformer("./paraphrase-distilroberta-base-v1")
      def forward(self, text):
          return torch.tensor(self.encoder.encode(text)) 

class ModelB(nn.Module):
      def __init__(self, option):
         self.option =option
         self.hidden2output = nn.Linear(self.option.hidden_dim, self.option.output_dim)
      def forward(self, emb, head_id, actual):
         output = self.hidden2output(emb)
         ....

class MyEnsemble(nn.Module):
       def __init__(self, modelA, modelB):
           super(MyEnsemble, self).__init__()
           self.modelA = ModelA
           self.modelB = ModelB
       def forward(self, a, b, c):
          emb = self.modelA(a)
         output = self.modelB(emb, b, c)
modelA = ModelA()
modelB = ModelB(option)
model = MyEnsemble(modelA, modelB)
for batch in all_batch:
     model(batch[0], batch[1], batch[2])

The error I got:

TypeError: _init_() takes 1 positional argument but 2 were given

If I just test modelA or modelB, there is no error

for batch in all_batch:
    modelA(batch[0])

I think your MyEnsemble forward taking three arguments a,b,c but only two are passed.

import torch.nn as nn

class ModelA(nn.Module):
    def __init__(self,n_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        #DO ANYTHING HERE
        return x 

class ModelB(nn.Module):
    def __init__(self,n_channels1,n_channels2,n_channels3):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x,y,z):
        #DO ANYTHING HERE
#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))
#         x = torch.flatten(x, 1) # flatten all dimensions except batch
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
        return x
class MyEnsemble(nn.Module):
    def __init__(self):
        super(MyEnsemble, self).__init__()
        self.modelA = ModelA(1)
        self.modelB = ModelB(1,1,1)
    def forward(self, a, b, c):
        emb = self.modelA(a)
        print("emb",emb)
        output = self.modelB(emb, b, c)
        
        return output
model = MyEnsemble()
model(1, 2, 3)

Sample code

Thanks a lot, mate! It helps.