How do I update my network with different losses for each component?


The image is a simplified version of my network. I want to update Network1 with Loss_Network1 and Network2 with Loss_Network2. However, when I call backward(), the grads are calculated for both networks regardless (e.g., the grads for Loss_Network1 are added onto the .grad of Network2). The only approach I can think of right now is:
Loss_Network1.backward()
optimizer_Network2.zero_grad()
copy the .grad for Network1
Loss_Network2.backward()
optimizer_Network1.zero_grad()
copy back the .grad for Network1
optimizer_Network1.step()
optimizer_Network2.step()

Also, calling optimizer_Network1.step() before Loss_Network2.backward() may seem like the obvious fix, but it isn't an option because it raises: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Please let me know if there is a simple way to, say, freeze .grad during backward() so that I won't need to copy the .grads. Thanks!!

Hi,

I’m afraid this is not easy to do, and what you’re doing right now is pretty close to optimal.
The only other approach would be to do two forward passes, one for Loss_Network1 and one for Loss_Network2, and update only the corresponding net before running the other’s backward. But this will be slower, I’m afraid.
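The two-forward approach described above could be sketched roughly as follows. The layer sizes, the chain layout (Network1 feeding Network2), the loss expressions, and the names `compute_losses` and `two_pass_step` are all hypothetical stand-ins, since the original network is not shown here.

```python
import torch
import torch.nn as nn


def compute_losses(net1, net2, x):
    """Hypothetical losses; both depend on the parameters of both networks."""
    out = net2(net1(x))
    loss1 = out.pow(2).mean()  # placeholder for Loss_Network1
    loss2 = out.abs().mean()   # placeholder for Loss_Network2
    return loss1, loss2


def two_pass_step(net1, net2, opt1, opt2, x):
    # Pass 1: fresh forward, backward on loss1, update Network1 only.
    opt1.zero_grad()
    loss1, _ = compute_losses(net1, net2, x)
    loss1.backward()
    opt1.step()

    # Pass 2: a new forward builds a new graph, so the in-place update of
    # Network1 above does not invalidate anything this backward needs.
    opt2.zero_grad()  # also discards what loss1 left on Network2
    _, loss2 = compute_losses(net1, net2, x)
    loss2.backward()
    opt2.step()


torch.manual_seed(0)
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 2)
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
two_pass_step(net1, net2, opt1, opt2, torch.randn(8, 4))
```

Note that in this sketch the second forward already sees the updated Network1, so the result is not identical to updating both networks from a single forward pass.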


Thanks for the reply! I just thought of another possible solution. Can we stop the backpropagation at some point, for example in between Net1 and Net2?

I have just implemented the copy and paste grad functions.

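def copy_grad(optimizer):
    """Return clones of the current .grad of every parameter the optimizer manages."""
    grad_list = []
    for group in optimizer.param_groups:
        for param in group['params']:
            # Parameters that received no gradient keep None as a placeholder.
            grad_list.append(None if param.grad is None else param.grad.clone())
    return grad_list

def paste_grad(optimizer, grad_list):
    """Write the gradients saved by copy_grad back into .grad, in the same order."""
    params = [p for group in optimizer.param_groups for p in group['params']]
    for param, grad in zip(params, grad_list):
        param.grad = grad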
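For illustration, here is one way these two helpers might slot into the step order listed in the first post. This is only a sketch under assumptions: the two `nn.Linear` layers, the chain layout, the loss expressions, and the names `compute_losses` and `train_step` are hypothetical, and the helpers are repeated in compact form so the block runs on its own.

```python
import torch
import torch.nn as nn


def copy_grad(optimizer):
    return [None if p.grad is None else p.grad.clone()
            for group in optimizer.param_groups for p in group["params"]]


def paste_grad(optimizer, grad_list):
    params = [p for group in optimizer.param_groups for p in group["params"]]
    for p, grad in zip(params, grad_list):
        p.grad = grad


def compute_losses(net1, net2, x):
    """Hypothetical losses; both depend on the parameters of both networks."""
    out = net2(net1(x))
    return out.pow(2).mean(), out.abs().mean()


def train_step(net1, net2, opt1, opt2, x):
    opt1.zero_grad()
    opt2.zero_grad()
    loss1, loss2 = compute_losses(net1, net2, x)

    # retain_graph=True because loss2 shares the same graph.
    loss1.backward(retain_graph=True)
    saved = copy_grad(opt1)      # Network1 grads from loss1 only
    opt2.zero_grad()             # drop loss1's contribution to Network2

    loss2.backward()             # Network2 now holds grads from loss2 only
    opt1.zero_grad()             # drop loss2's contribution to Network1
    paste_grad(opt1, saved)      # restore the grads saved earlier

    # Both backward passes are done, so in-place updates are safe now.
    opt1.step()
    opt2.step()


torch.manual_seed(0)
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 2)
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
train_step(net1, net2, opt1, opt2, torch.randn(8, 4))
```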

You can, but you will still need two calls to backprop, so I don’t think there will be a very large benefit from this.
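One way to limit each backward pass to a single network, sketched here under assumptions, is to ask `torch.autograd.grad` for gradients with respect to only that network's parameters and assign them to `.grad` by hand. It still runs two backward passes, as noted above, but nothing has to be copied and restored. The networks, losses, and the names `compute_losses` and `step_with_autograd_grad` are hypothetical.

```python
import torch
import torch.nn as nn


def compute_losses(net1, net2, x):
    """Hypothetical losses; both depend on the parameters of both networks."""
    out = net2(net1(x))
    return out.pow(2).mean(), out.abs().mean()


def step_with_autograd_grad(net1, net2, opt1, opt2, x):
    loss1, loss2 = compute_losses(net1, net2, x)
    params1 = list(net1.parameters())
    params2 = list(net2.parameters())

    # torch.autograd.grad returns the gradients instead of accumulating
    # them into .grad, so the other network's .grad is never touched.
    grads1 = torch.autograd.grad(loss1, params1, retain_graph=True)
    grads2 = torch.autograd.grad(loss2, params2)

    for p, g in zip(params1, grads1):
        p.grad = g
    for p, g in zip(params2, grads2):
        p.grad = g

    # Step only after both backward passes have finished.
    opt1.step()
    opt2.step()


torch.manual_seed(0)
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 2)
opt1 = torch.optim.SGD(net1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(net2.parameters(), lr=0.1)
step_with_autograd_grad(net1, net2, opt1, opt2, torch.randn(8, 4))
```

Recent PyTorch releases also accept an `inputs` argument to `backward()`, which accumulates gradients only into the listed tensors and may serve the same purpose.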


Is there a better or newer way to achieve this that has come into use since the last update on this discussion?
@albanD