As you can see, the intermediate output of one branch is used as an input to the other branch.
class MyNet(nn.Module):
    """Two-branch network built on a single shared CNN trunk.

    Branch 1: trunk -> encoder1 -> discriminator.
    Branch 2: trunk -> encoder2, concatenated with branch 1's encoding,
    -> classifier.

    forward() returns ``(out1, out2)``: the discriminator output and the
    classifier output. Because both come from one graph, the two heads can be
    trained with different criteria by combining their losses (or by using
    separate optimizers over the relevant parameter groups).
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it, `self.convnet = ...` raises AttributeError.
        super().__init__()
        self.convnet = CNN()
        self.encoder1 = Encoder()
        self.discriminator = Discriminator()
        self.encoder2 = Encoder()
        self.classifier = Classifier()

    def forward(self, x):
        """Run both branches on ``x``.

        Args:
            x: input batch for the CNN trunk.

        Returns:
            tuple: ``(out1, out2)`` — discriminator output and classifier
            output.
        """
        features = self.convnet(x)

        # Branch 1: encoding feeds the discriminator and is reused by branch 2.
        out = self.encoder1(features)
        out1 = self.discriminator(out)

        # Branch 2: encode the same trunk features with the second encoder.
        out2 = self.encoder2(features)
        # torch.cat takes a *sequence* of tensors. dim=1 assumes the encoder
        # outputs are (batch, features, ...) — TODO confirm against Encoder.
        out2 = torch.cat((out, out2), dim=1)
        out2 = self.classifier(out2)
        return out1, out2
Maybe I should write the network as two separate modules instead:
class Branch1(nn.Module):
    """First branch: CNN -> encoder -> discriminator.

    forward() returns ``(out, out1)``: the encoder output (so it can be fed
    into another branch) and the discriminator output.
    """

    def __init__(self):
        # Required before assigning submodules, otherwise nn.Module raises
        # AttributeError on `self.convnet = ...`.
        super().__init__()
        self.convnet = CNN()
        self.encoder = Encoder()
        self.discriminator = Discriminator()

    def forward(self, x):
        """Return ``(encoding, discriminator_output)`` for input ``x``."""
        x = self.convnet(x)
        out = self.encoder(x)
        out1 = self.discriminator(out)
        return out, out1
class Branch2(nn.Module):
    """Second branch: CNN features fused with another branch's output.

    CNN(x) is concatenated with ``branch`` (e.g. Branch1's encoding), then
    passed through an encoder and a classifier.

    NOTE(review): unlike MyNet, the concatenation here happens *before* the
    encoder, and this branch owns its own CNN instead of sharing one — confirm
    that this difference is intended.
    """

    def __init__(self):
        # Required before assigning submodules on an nn.Module.
        super().__init__()
        self.convnet = CNN()
        self.encoder = Encoder()
        # Was `classifier()` (undefined lowercase name); the class is Classifier.
        self.classifier = Classifier()

    def forward(self, x, branch):
        """Fuse CNN(x) with ``branch`` and classify.

        Args:
            x: input batch for this branch's CNN.
            branch: tensor produced by another branch. It must match CNN(x)
                on every dim except the concatenation dim.

        Returns:
            tuple: ``(encoding, out)`` — the encoder output and the classifier
            output. This mirrors Branch1's ``(encoding, head_output)``. The
            original returned an undefined ``out1`` and always raised
            NameError.
        """
        x = self.convnet(x)
        # torch.cat takes a *sequence* of tensors. dim=1 assumes a
        # (batch, features, ...) layout — TODO confirm against CNN/Encoder.
        x = torch.cat((x, branch), dim=1)
        encoding = self.encoder(x)
        out = self.classifier(encoding)
        return encoding, out
Then, how do I train the two parts, each with its own loss function?