I have the model below. How can I train it, given the if condition in forward()? I want optimizer.step() to update modelA, modelB, and the combined modelA + modelB (MyEnsemble). The condition causes a problem when the input is a batch.
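import torch.nn as nn

class MyEnsemble(nn.Module):
    def __init__(self, modelA, modelB):
        super(MyEnsemble, self).__init__()
        self.modelA = modelA
        self.modelB = modelB

    def forward(self, x1):
        output = self.modelA(x1)
        print(output)
        # Only the first sample of the batch is tested here,
        # but the branch is then applied to the whole batch.
        if output[0] > 0.5:
            out = self.modelB(x1)
        # out is only assigned when the condition above is true.
        return out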
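
One possible way around the batch issue, as a minimal sketch rather than a definitive answer: replace the Python if with a per-sample boolean mask, so every sample in the batch gets its own decision. The names MaskedEnsemble, gate_target and y, the layer sizes, and the two loss functions are all hypothetical and chosen only for illustration; the sketch also assumes modelA outputs one value in [0, 1] per sample. Note that the hard threshold passes no gradient back to modelA, so in this sketch modelA is trained through its own loss term.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedEnsemble(nn.Module):
    """Hypothetical per-sample variant of MyEnsemble (illustration only)."""

    def __init__(self, modelA, modelB):
        super().__init__()
        self.modelA = modelA
        self.modelB = modelB

    def forward(self, x1):
        gate = self.modelA(x1)          # assumed shape (batch, 1), values in [0, 1]
        mask = gate.squeeze(1) > 0.5    # one boolean decision per sample
        out = self.modelB(x1)           # run on the full batch for simplicity
        return gate, out, mask


torch.manual_seed(0)

# Stand-in models and data (assumptions, not the original modelA / modelB).
modelA = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
modelB = nn.Linear(4, 3)
model = MaskedEnsemble(modelA, modelB)

# A single optimizer over model.parameters() covers both sub-models,
# because they are registered as submodules of the ensemble.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 4)
gate_target = torch.randint(0, 2, (8, 1)).float()  # hypothetical labels for modelA
y = torch.randn(8, 3)                               # hypothetical targets for modelB

gate, out, mask = model(x)

# Loss for modelA on its own output.
loss_a = F.binary_cross_entropy(gate, gate_target)

# Loss for modelB, counted only for samples where the gate fired.
per_sample = ((out - y) ** 2).mean(dim=1)
weights = mask.float()
loss_b = (per_sample * weights).sum() / weights.sum().clamp(min=1.0)

loss = loss_a + loss_b
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Running modelB on the whole batch and masking in the loss wastes some compute, but it keeps the shapes fixed and avoids an empty batch when no sample passes the threshold.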