Sirui_Li
(Sirui Li)
1
I built up an ensemble model:
class ModelA(nn.Modue):
def __init__(self):
super(ModelA, self).__init__()
self.encoder = SentenceTransformer("./paraphrase-distilroberta-base-v1")
def forward(self, text):
return torch.tensor(self.encoder.encode(text))
class ModelB(nn.Module):
def __init__(self, option):
self.option =option
self.hidden2output = nn.Linear(self.option.hidden_dim, self.option.output_dim)
def forward(self, emb, head_id, actual):
output = self.hidden2output(emb)
....
class MyEnsemble(nn.Module):
def __init__(self, modelA, modelB):
super(MyEnsemble, self).__init__()
self.modelA = ModelA
self.modelB = ModelB
def forward(self, a, b, c):
emb = self.modelA(a)
output = self.modelB(emb, b, c)
modelA = ModelA()
modelB = ModelB(option)
model = MyEnsemble(modelA, modelB)
for batch in all_batch:
model(batch[0], batch[1], batch[2])
The error I got:
TypeError: _init_() takes 1 positional argument but 2 were given
If I just test modelA or modelB, there is no error
for batch in all_batch:
modelA(batch[0])
I think your MyEnsemble
forward taking three arguments a,b,c but only two are passed.
import torch.nn as nn
class ModelA(nn.Module):
def __init__(self,n_channels):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
#DO ANYTHING HERE
return x
class ModelB(nn.Module):
def __init__(self,n_channels1,n_channels2,n_channels3):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x,y,z):
#DO ANYTHING HERE
# x = self.pool(F.relu(self.conv1(x)))
# x = self.pool(F.relu(self.conv2(x)))
# x = torch.flatten(x, 1) # flatten all dimensions except batch
# x = F.relu(self.fc1(x))
# x = F.relu(self.fc2(x))
# x = self.fc3(x)
return x
class MyEnsemble(nn.Module):
def __init__(self):
super(MyEnsemble, self).__init__()
self.modelA = ModelA(1)
self.modelB = ModelB(1,1,1)
def forward(self, a, b, c):
emb = self.modelA(a)
print("emb",emb)
output = self.modelB(emb, b, c)
return output
model = MyEnsemble()
model(1, 2, 3)
Sample code
Sirui_Li
(Sirui Li)
4
Thanks a lot, mate! It helps.