How to train a network with different branches (not parallel)?


As you can see, the intermediate output of one branch is the input of another branch.

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = CNN()
        self.encoder1 = Encoder()
        self.discriminator = Discriminator()
        self.encoder2 = Encoder()
        self.classifier = Classifier()

    def forward(self, x):
        # ...
        x = self.convnet(x)
        out = self.encoder1(x)
        out1 = self.discriminator(out)
        out2 = self.encoder(2)
        out2 = torch.cat((out, out2), dim=1)
        out2 = self.classifier(out2)
        return out1, out2

Or maybe I should write the network as separate modules:

class Branch1(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = CNN()
        self.encoder = Encoder()
        self.discriminator = Discriminator()

    def forward(self, x):
        x = self.convnet(x)
        out = self.encoder(x)
        out1 = self.discriminator(out)
        return out, out1

class Branch2(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = CNN()
        self.encoder = Encoder()
        self.classifier = Classifier()

    def forward(self, x, branch):
        x = self.convnet(x)
        x = torch.cat((x, branch), dim=1)  # branch: the `out` returned by Branch1
        out = self.encoder(x)
        out = self.classifier(out)
        return out
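A minimal sketch of how the two separate modules could be wired together, assuming flat 16-dimensional inputs and using nn.Linear layers as hypothetical stand-ins for CNN, Encoder, Discriminator and Classifier. The wrapper name TwoBranchNet is also hypothetical. Passing Branch1's `out` straight into Branch2 (without .detach()) keeps both branches in one autograd graph, so a loss on Branch2's output also updates Branch1's convnet and encoder.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: Linear layers on flat 16-dim inputs replace
# CNN / Encoder / Discriminator / Classifier so the sketch runs on its own.
class Branch1(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = nn.Linear(16, 8)
        self.encoder = nn.Linear(8, 4)
        self.discriminator = nn.Linear(4, 1)

    def forward(self, x):
        x = torch.relu(self.convnet(x))
        out = self.encoder(x)
        out1 = self.discriminator(out)
        return out, out1


class Branch2(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = nn.Linear(16, 8)
        # input width: 8 convnet features + 4 features from Branch1's encoder
        self.encoder = nn.Linear(8 + 4, 4)
        self.classifier = nn.Linear(4, 3)

    def forward(self, x, branch):
        x = torch.relu(self.convnet(x))
        x = torch.cat((x, branch), dim=1)
        out = self.encoder(x)
        return self.classifier(out)


class TwoBranchNet(nn.Module):
    """Hypothetical wrapper: Branch1's intermediate output feeds Branch2."""
    def __init__(self):
        super().__init__()
        self.branch1 = Branch1()
        self.branch2 = Branch2()

    def forward(self, x):
        out, out1 = self.branch1(x)
        # no .detach(): gradients of out2 also flow back into branch1
        out2 = self.branch2(x, out)
        return out1, out2


model = TwoBranchNet()
x = torch.randn(5, 16)
out1, out2 = model(x)
print(out1.shape, out2.shape)  # torch.Size([5, 1]) torch.Size([5, 3])
```

If Branch2's loss should not update Branch1, pass out.detach() instead of out; otherwise this behaves like the single-module version.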

Then how do I handle the two parts, each with its own loss?

Hi,

I think the first option is the right one (I corrected a typo: self.encoder(2) should be self.encoder2(x)):

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = CNN()
        self.encoder1 = Encoder()
        self.discriminator = Discriminator()
        self.encoder2 = Encoder()
        self.classifier = Classifier()

    def forward(self, x):
        # ...
        x = self.convnet(x)
        out = self.encoder1(x)
        out1 = self.discriminator(out)
        out2 = self.encoder2(x)
        out2 = torch.cat((out, out2), dim=1)  # torch.cat takes a tuple/list of tensors
        out2 = self.classifier(out2)
        return out1, out2

Regarding the loss, just write a single loss that combines both outputs, for example:

opt.zero_grad()
out1, out2 = model(x)
# alpha weights the second loss relative to the first
loss = criterion1(out1, ground_truth1) + alpha * criterion2(out2, ground_truth2)
loss.backward()
opt.step()
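A minimal runnable sketch of this combined-loss training step, assuming flat 16-dimensional inputs and using nn.Linear layers as hypothetical stand-ins for CNN, Encoder, Discriminator and Classifier. The choice of BCEWithLogitsLoss for criterion1, CrossEntropyLoss for criterion2, plain SGD, the random targets and alpha = 0.5 are all illustrative assumptions, not part of the original answer.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins: Linear layers replace CNN / Encoder /
# Discriminator / Classifier; inputs are flat 16-dim vectors.
class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnet = nn.Sequential(nn.Linear(16, 8), nn.ReLU())
        self.encoder1 = nn.Linear(8, 4)
        self.discriminator = nn.Linear(4, 1)   # one logit
        self.encoder2 = nn.Linear(8, 4)
        self.classifier = nn.Linear(4 + 4, 3)  # sees both encoders' outputs

    def forward(self, x):
        x = self.convnet(x)
        out = self.encoder1(x)
        out1 = self.discriminator(out)
        out2 = self.encoder2(x)
        out2 = self.classifier(torch.cat((out, out2), dim=1))
        return out1, out2


torch.manual_seed(0)
model = MyNet()
criterion1 = nn.BCEWithLogitsLoss()  # assumed loss for the discriminator output
criterion2 = nn.CrossEntropyLoss()   # assumed loss for the classifier output
opt = torch.optim.SGD(model.parameters(), lr=0.1)
alpha = 0.5                          # hypothetical weight of the second loss

# Random placeholder batch: 8 samples, binary target for out1, 3 classes for out2.
x = torch.randn(8, 16)
ground_truth1 = torch.randint(0, 2, (8, 1)).float()
ground_truth2 = torch.randint(0, 3, (8,))


def train_step():
    opt.zero_grad()
    out1, out2 = model(x)
    loss = criterion1(out1, ground_truth1) + alpha * criterion2(out2, ground_truth2)
    loss.backward()  # one backward pass fills .grad for every submodule
    opt.step()
    return loss.item()


losses = [train_step() for _ in range(50)]
print(losses[0], losses[-1])
```

alpha only rescales the classifier term's gradients; a reasonable value depends on the relative scale of the two losses and is usually tuned on validation data.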

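Because both terms are summed into a single scalar, one backward() call propagates gradients through both branches; the shared convnet and encoder1 receive gradients from both losses.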

Hope this helps.