Forward to and backward from some intermediate output

Hi,
If a network has multiple branches and multiple outputs, can it be trained using only the loss from one branch? That is, the forward pass computes only this branch, its output and its loss, leaving the other branches aside, and the backward pass updates only the parameters of branch1.

In TensorFlow this is easy: if loss1 depends only on branch1, define
opt1_op = optimizer.minimize(loss1)
and then just run sess.run(opt1_op, feed_dict={x: x})

But in PyTorch, as far as I know, the network has only one forward function, so I have to make it compute all the branches. This is wasteful when I only need to train one specific branch.

Hi, can anyone help me?
Can I write my code like this?

def forward(self, x, flag):
    # The shared trunk always runs; only the head selected by `flag` runs after it.
    x = self.shared_branch(x)
    if flag == 0:
        output = self.branch0(x)
    elif flag == 1:
        output = self.branch1(x)
    elif flag == 2:
        output = self.branch2(x)
    else:
        output = self.branch3(x)
    return output
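A minimal runnable sketch of this approach follows. The class name `MultiBranchNet`, the layer sizes, the MSE loss and the SGD optimizer are all assumptions for illustration, not the poster's actual network. It runs one training step with `flag=1` and reports which parameters received a gradient and which ones the step changed.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiBranchNet(nn.Module):
    """Hypothetical shared trunk with four heads; sizes are illustrative."""

    def __init__(self, in_dim=8, hidden=16, out_dim=2):
        super().__init__()
        self.shared_branch = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.branch0 = nn.Linear(hidden, out_dim)
        self.branch1 = nn.Linear(hidden, out_dim)
        self.branch2 = nn.Linear(hidden, out_dim)
        self.branch3 = nn.Linear(hidden, out_dim)

    def forward(self, x, flag):
        x = self.shared_branch(x)
        # Only the head chosen by `flag` executes, so only it enters the autograd graph.
        if flag == 0:
            return self.branch0(x)
        elif flag == 1:
            return self.branch1(x)
        elif flag == 2:
            return self.branch2(x)
        return self.branch3(x)


torch.manual_seed(0)
model = MultiBranchNet()
# The optimizer may hold every parameter; heads that did not run get no gradient.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8)
target = torch.randn(4, 2)

# Snapshot the parameters so we can see which ones the step changed.
before = {n: p.detach().clone() for n, p in model.named_parameters()}

optimizer.zero_grad()
loss1 = F.mse_loss(model(x, flag=1), target)
loss1.backward()
optimizer.step()

changed = {n: not torch.equal(before[n], p.detach()) for n, p in model.named_parameters()}
for n, p in model.named_parameters():
    grad_state = "set" if p.grad is not None else "None"
    print(f"{n:24s} grad={grad_state:5s} changed={changed[n]}")
```

Under these assumptions, `branch0`, `branch2` and `branch3` end up with `.grad` still `None` and are left untouched, because PyTorch's built-in optimizers such as SGD skip parameters whose gradient is `None`. Note that `shared_branch` is updated along with `branch1`, since it is part of the computation that produced `loss1`.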

Hi,

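In PyTorch, the forward pass computes only what you actually call to get the output, and autograd records just those operations, so the backward pass only goes through that part. Branches that are never called in forward are not in the graph and receive no gradients.
So yes, your code will do what you want.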
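With the flag-based forward from the question, `loss1` also updates `shared_branch`, because the trunk is part of the computation. If, as the original question asked, only branch1's parameters should change, one option is to freeze the trunk with `requires_grad_(False)` and give the optimizer only branch1's parameters. The sketch below assumes standalone modules with made-up sizes rather than the poster's real network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical trunk and two heads; sizes are illustrative assumptions.
shared_branch = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
branch0 = nn.Linear(16, 2)
branch1 = nn.Linear(16, 2)

# Freeze the trunk so autograd does not compute gradients for it at all.
shared_branch.requires_grad_(False)

# The optimizer only ever sees branch1's parameters.
optimizer = torch.optim.SGD(branch1.parameters(), lr=0.1)

x = torch.randn(4, 8)
target = torch.randn(4, 2)
shared_before = shared_branch[0].weight.detach().clone()
branch1_before = branch1.weight.detach().clone()

optimizer.zero_grad()
loss1 = F.mse_loss(branch1(shared_branch(x)), target)
loss1.backward()
optimizer.step()

print("shared grad:", shared_branch[0].weight.grad)  # None: trunk is frozen
print("branch0 grad:", branch0.weight.grad)  # None: branch0 never ran
print("branch1 updated:", not torch.equal(branch1_before, branch1.weight))
```

Freezing the trunk also lets autograd skip its gradient computation entirely. Passing only `branch1.parameters()` to the optimizer without freezing would keep the trunk unchanged too, but its gradients would still be computed (and accumulated) without being applied.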


Thanks! PyTorch is magic!!