How to train a network with multiple branches

Here is an example:

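import torch
import torch.nn as nn

class mm(nn.Module):
    def __init__(self):
        super(mm, self).__init__()
        self.n = nn.Linear(4, 3)   # shared layer used by both branches
        self.m = nn.Linear(3, 2)   # head of branch 1
        self.m2 = nn.Linear(3, 4)  # head of branch 2

    def forward(self, input, input2):
        input_ = self.n(input)
        input2_ = self.n(input2)
        o1 = self.m(input_)    # output of branch 1
        o2 = self.m2(input2_)  # output of branch 2
        return o1, o2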

I want to train m and m2 with different losses. Do I need to create two optimizers?
Please show me some examples.
Thanks


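You can simply do:

model = mm()
o1, o2 = model(input, input2)
# sum the two losses into one scalar, then backpropagate once
total_loss = loss(o1, target) + loss2(o2, target2)
total_loss.backward()

# Or you can do (instead of the above)

l1 = loss(o1, target)
l2 = loss2(o2, target2)
torch.autograd.backward([l1, l2])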
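A minimal runnable sketch of both options, assuming the mm class from the question, random inputs and targets, and nn.MSELoss for both heads (the data, the loss choice, and the batch size of 8 are illustrative assumptions). Each option starts from a fresh forward pass, because the first backward call frees the graph:

```python
import torch
import torch.nn as nn

class mm(nn.Module):
    def __init__(self):
        super().__init__()
        self.n = nn.Linear(4, 3)   # shared layer
        self.m = nn.Linear(3, 2)   # head of branch 1
        self.m2 = nn.Linear(3, 4)  # head of branch 2

    def forward(self, input, input2):
        o1 = self.m(self.n(input))
        o2 = self.m2(self.n(input2))
        return o1, o2

torch.manual_seed(0)
model = mm()
loss_fn = nn.MSELoss()  # assumed loss for both heads
x1, x2 = torch.randn(8, 4), torch.randn(8, 4)
target, target2 = torch.randn(8, 2), torch.randn(8, 4)

# Option 1: sum the losses and call backward once
o1, o2 = model(x1, x2)
(loss_fn(o1, target) + loss_fn(o2, target2)).backward()
grads_sum = [p.grad.clone() for p in model.parameters()]

# Option 2: hand both losses to torch.autograd.backward
model.zero_grad()
o1, o2 = model(x1, x2)
torch.autograd.backward([loss_fn(o1, target), loss_fn(o2, target2)])
grads_list = [p.grad.clone() for p in model.parameters()]

# Compare the gradients left by the two options
print(all(torch.allclose(a, b) for a, b in zip(grads_sum, grads_list)))
```

Both options deposit the same gradients: torch.autograd.backward over a list of scalar losses accumulates their gradients, just as backpropagating their sum does.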

Regarding the code above, is it OK to pass the same input to two different networks, i.e.:

def forward(self, input, input2):
    input_ = self.n(input)
    o1 = self.m(input_)
    o2 = self.m2(input_)
    return o1, o2

Or do I need to copy the input_ Variable and pass the copy to the m2 network?


You can pass the same input to two different networks; there is no need to copy it.
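A minimal sketch, assuming a hypothetical SharedInput module built from the same layers: the intermediate tensor input_ feeds both heads directly, autograd records both uses, and the shared layer n receives gradient from both heads:

```python
import torch
import torch.nn as nn

class SharedInput(nn.Module):  # hypothetical name for this sketch
    def __init__(self):
        super().__init__()
        self.n = nn.Linear(4, 3)
        self.m = nn.Linear(3, 2)
        self.m2 = nn.Linear(3, 4)

    def forward(self, input):
        input_ = self.n(input)
        # The same tensor goes to both heads; no copy is made
        return self.m(input_), self.m2(input_)

torch.manual_seed(0)
net = SharedInput()
o1, o2 = net(torch.randn(5, 4))
(o1.sum() + o2.sum()).backward()
# All three layers, including the shared n, now hold gradients
print(net.n.weight.grad is not None, net.m.weight.grad is not None, net.m2.weight.grad is not None)
```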


If I pass the same input to different networks with different optimizers, what does the gradient for this input look like? Is it the sum of the gradients from the two branches?


Yes, it is summed. PyTorch gradients are always accumulated.
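To see the accumulation concretely, here is a small check under assumed shapes: two hypothetical nn.Linear branches read the same leaf tensor x, and the gradient left in x.grad after one backward pass equals the sum of the gradients each branch produces on its own (torch.autograd.grad returns those without writing to .grad):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
branch1 = nn.Linear(3, 2)
branch2 = nn.Linear(3, 4)
x = torch.randn(5, 3, requires_grad=True)  # shared input (a leaf, so .grad is kept)

# Gradient w.r.t. x from each branch alone
g1, = torch.autograd.grad(branch1(x).sum(), x)
g2, = torch.autograd.grad(branch2(x).sum(), x)

# Backpropagating through both branches accumulates into x.grad
(branch1(x).sum() + branch2(x).sum()).backward()
print(torch.allclose(x.grad, g1 + g2))  # True
```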


If so, how do I apply different optimizers to the two branches? I mean, branch 1 has loss_1 and branch 2 has loss_2:

loss = 2*loss_1 + 3*loss_2
loss.backward()
optimizer.step()

How can I wrap the two branches into one optimizer when that optimizer actually has to contain two optimization methods (Adam -> branch_1 and SGD -> branch_2)?

opt1 = optim.Adam(branch_1.parameters(), ...)
opt2 = optim.SGD(branch_2.parameters(), ...)
...
...
loss = 2*loss_1 + 3*loss_2
loss.backward()  # computes gradients for both branches
opt1.step()      # Adam updates branch_1
opt2.step()      # SGD updates branch_2
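A self-contained version of this pattern, as a sketch: the branch shapes, learning rates, data, and MSELoss are assumptions chosen only so it runs. The zero_grad() calls on both optimizers keep gradients from accumulating across iterations:

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
branch_1 = nn.Linear(4, 2)
branch_2 = nn.Linear(4, 3)
# Assumed hyperparameters, for illustration only
opt1 = optim.Adam(branch_1.parameters(), lr=1e-3)
opt2 = optim.SGD(branch_2.parameters(), lr=1e-2)
criterion = nn.MSELoss()

x = torch.randn(16, 4)
target_1, target_2 = torch.randn(16, 2), torch.randn(16, 3)
w1_before = branch_1.weight.detach().clone()
w2_before = branch_2.weight.detach().clone()

for step in range(3):
    opt1.zero_grad()
    opt2.zero_grad()
    loss_1 = criterion(branch_1(x), target_1)
    loss_2 = criterion(branch_2(x), target_2)
    loss = 2 * loss_1 + 3 * loss_2
    loss.backward()  # fills .grad for both branches
    opt1.step()      # Adam updates branch_1 only
    opt2.step()      # SGD updates branch_2 only
```

Each built-in optimizer implements a single update rule, so using Adam for one branch and SGD for the other means keeping two optimizer objects; one combined backward() still fills the gradients that both of them read.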

Can PyTorch handle backprop to separate branches if you concatenate the outputs of two branches, feed them into a single linear layer, and then go deeper in the network until you compute a final output?

For example:

  • Branch_1 takes channel 1 of the input image and performs convolutions.
  • Branch_2 takes channel 2 of the input image and performs convolutions.
  • The outputs of both branches are flattened into 1D arrays and concatenated.
  • More fully connected (linear) layers follow the concatenation.
  • Finally, the network outputs a value or class scores.

(Note: This is just an example to illustrate the use of branched networks for two input images)

Many thanks in advance!


Yes, you can do that. Since the graph is built dynamically, when you call torch.cat, autograd joins the graphs of your branches into a single graph and handles backprop correctly.
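A small sketch of the branched design described in the question, assuming a hypothetical TwoBranchNet with 8x8 inputs and tiny channel counts: each channel goes through its own conv branch, the flattened features are joined with torch.cat, and a single backward() reaches both branches:

```python
import torch
import torch.nn as nn

class TwoBranchNet(nn.Module):  # hypothetical small model for this sketch
    def __init__(self):
        super().__init__()
        self.branch_1 = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.branch_2 = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        # 8x8 input -> 4x4 after pooling, 4 channels, two branches
        self.head = nn.Sequential(nn.Linear(2 * 4 * 4 * 4, 16), nn.ReLU(), nn.Linear(16, 1))

    def forward(self, img):
        f1 = self.branch_1(img[:, 0:1]).flatten(1)  # channel 1 -> [N, 64]
        f2 = self.branch_2(img[:, 1:2]).flatten(1)  # channel 2 -> [N, 64]
        return self.head(torch.cat((f1, f2), dim=1))

torch.manual_seed(0)
net = TwoBranchNet()
out = net(torch.randn(3, 2, 8, 8))
out.sum().backward()  # gradients flow back through the cat into both branches
```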


That’s great news, thank you!

Is there a specific way in which to pass your batches to a branched network?

For example, say your batch has dimensions [N, C, H, W] = [67, 2, 64, 64] and your network looks like the following:

import torch

class BranchedNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # branch for the first channel
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2))

        # branch for the second channel
        self.feature_extractor2 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 64, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Conv2d(128, 256, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2))

        # 64x64 input, two 2x2 max-pools -> 256 maps of 16x16 per branch
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear((256*16*16)*2, 264),
            torch.nn.ReLU(),
            torch.nn.Linear(264, 264),
            torch.nn.ReLU(),
            torch.nn.Linear(264, 1))

    def forward(self, x, y):
        # x and y are [N, 1, H, W] tensors, one per input channel
        features = self.feature_extractor(x)
        features = features.view(x.size(0), -1)
        features2 = self.feature_extractor2(y)  # second branch, not the first
        features2 = features2.view(y.size(0), -1)
        grouped = torch.cat((features, features2), dim=1)  # join along the feature dimension
        output = self.classifier(grouped)

        return output

Would you have to specify what x and y are during training?

I ask because when I try this with the structure above, I get an error:

TypeError: forward() missing 1 required positional argument: 'y'

when simply passing model(batch) during training.

EDIT:

Solved by @ptrblck; see "Concatenate layer output with additional input data".
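One way to feed a forward(x, y) model from an [N, 2, H, W] batch, as a sketch: the TypeError arises because model(batch) passes one argument where forward expects two, so the channels are split and passed separately. SmallBranchedNet is a hypothetical, scaled-down stand-in for BranchedNet, assuming each branch should see one channel:

```python
import torch
import torch.nn as nn

class SmallBranchedNet(nn.Module):  # hypothetical stand-in for BranchedNet
    def __init__(self):
        super().__init__()
        self.feature_extractor = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        self.feature_extractor2 = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2))
        # 64x64 input -> 32x32 after one pool, 8 channels, two branches
        self.classifier = nn.Linear(2 * 8 * 32 * 32, 1)

    def forward(self, x, y):
        features = self.feature_extractor(x).flatten(1)
        features2 = self.feature_extractor2(y).flatten(1)
        return self.classifier(torch.cat((features, features2), dim=1))

model = SmallBranchedNet()
batch = torch.randn(67, 2, 64, 64)  # [N, C, H, W] as in the question
# Slice each channel while keeping the channel dim, so each is [N, 1, H, W]
x, y = batch[:, 0:1], batch[:, 1:2]
output = model(x, y)  # two arguments, instead of model(batch)
```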


What if loss_1 and loss_2 come from two separate nn.Module networks and final_loss = 2*loss_1 + 3*loss_2? Would a single final_loss.backward() still compute gradients in both networks?
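Yes: final_loss is a single scalar whose graph includes both networks, so one backward() computes gradients in each. A minimal sketch with two independent nn.Linear modules; the placeholder losses are arbitrary assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net_a = nn.Linear(4, 1)  # first, independent network
net_b = nn.Linear(4, 1)  # second, independent network
x = torch.randn(10, 4)

loss_1 = net_a(x).pow(2).mean()  # placeholder loss for net_a
loss_2 = net_b(x).abs().mean()   # placeholder loss for net_b
final_loss = 2 * loss_1 + 3 * loss_2
final_loss.backward()  # one call reaches both graphs
```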

Hi, this might sound trivial, but how exactly do we get branch_1 and branch_2 from a single class object? I thought that for a model object, say model, there is only one set of parameters for both networks, i.e. model.parameters(). Sorry if this is trivial; I am just getting started.
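A sketch assuming the mm layout from the first post: every submodule assigned as an attribute (model.n, model.m, model.m2) has its own parameters(), which is how branch_1 and branch_2 can be taken from a single model; model.parameters() simply yields all of them. The learning rates are illustrative assumptions:

```python
import torch.nn as nn
import torch.optim as optim

class mm(nn.Module):
    # forward omitted: only the parameter layout matters here
    def __init__(self):
        super().__init__()
        self.n = nn.Linear(4, 3)   # shared layer
        self.m = nn.Linear(3, 2)   # branch 1
        self.m2 = nn.Linear(3, 4)  # branch 2

model = mm()
# Each registered submodule exposes its own parameters
branch_1_params = list(model.m.parameters())   # weight and bias of m
branch_2_params = list(model.m2.parameters())  # weight and bias of m2

opt1 = optim.Adam(model.m.parameters(), lr=1e-3)
opt2 = optim.SGD(model.m2.parameters(), lr=1e-2)

# With one optimizer type, parameter groups give per-branch settings
opt = optim.SGD([
    {"params": model.n.parameters()},
    {"params": model.m.parameters(), "lr": 1e-3},
    {"params": model.m2.parameters()},
], lr=1e-2)
```

With two optimizers as above, the shared layer n is updated only if its parameters are given to one of them. When both branches can use the same algorithm, parameter groups in one optimizer allow per-branch settings such as lr.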

Can I use this idea to handle multiple predictions or branches of a single network with different loss functions? I.e.:

outputs = model(input)  # [32, 7, 256, 256]
o1 = outputs[:, [0,1], :,:]  # [32, 2, 256, 256]
o2 = outputs[:, [2,3,4,5,6], :,:]  # [32, 5, 256, 256]

# Approach I

l1 = loss1(o1, target1)
l2 = loss2(o2, target2)

if phase == 'train':
    l1.backward(retain_graph=True) 
    l2.backward()
    optimizer.step()

# Approach II

l1 = loss1(o1, target1)
l2 = loss2(o2, target2)

if phase == 'train':
    (l1+l2).backward() 
    optimizer.step()

Which approach is more computationally efficient and helps the model learn better (i.e., retain what it learned in previous epochs)? With the second approach, do the reported values of l1 and l2 still reflect the real individual losses after backpropagating their sum in later epochs?
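A sketch comparing the two approaches, assuming a hypothetical one-layer Conv2d with 7 output channels in place of the real model, small tensors, and nn.MSELoss for both losses. Because backward() accumulates into .grad, Approach I and Approach II leave the same gradients:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Conv2d(3, 7, 3, padding=1)  # hypothetical stand-in with 7 output channels
inp = torch.randn(2, 3, 8, 8)
target1, target2 = torch.randn(2, 2, 8, 8), torch.randn(2, 5, 8, 8)
loss1 = loss2 = nn.MSELoss()  # assumed losses

def grads_after(approach):
    model.zero_grad()
    outputs = model(inp)
    o1, o2 = outputs[:, 0:2], outputs[:, 2:7]
    l1, l2 = loss1(o1, target1), loss2(o2, target2)
    if approach == 1:
        l1.backward(retain_graph=True)  # keep the graph for the second call
        l2.backward()                   # gradients accumulate into .grad
    else:
        (l1 + l2).backward()            # one backward pass through the shared graph
    return [p.grad.clone() for p in model.parameters()], l1.item(), l2.item()

g_a, l1_a, l2_a = grads_after(1)
g_b, l1_b, l2_b = grads_after(2)
print(all(torch.allclose(a, b) for a, b in zip(g_a, g_b)))
```

Since the gradients match (up to floating-point rounding), optimizer.step() makes the same update either way, so neither approach learns better. Approach II traverses the graph once and needs no retain_graph=True, which is typically cheaper. In both cases l1.item() and l2.item() report the individual losses, because l1 + l2 is a new tensor and does not modify l1 or l2.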

Hi Dhawgupta, I have a very similar question. Have you found a solution?