How should optimizers be defined in the following scenario, where losses are calculated in multiple branches?

I have defined 7 separate networks in 7 separate classes, each of which inherits from nn.Module. I am calculating losses at 4 separate places, as follows:

params = list(network1.parameters()) + list(network2.parameters()) + list(network3.parameters()) + \
         list(network4.parameters()) + list(network5.parameters()) + list(network6.parameters()) + \
         list(network7.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

for iter in range(max_epoch):
    network1_output = network1(input)
    network2_output = network2(network1_output)
    network3_output = network3(network1_output)
    network4_output = network4(network2_output)
    loss1 = loss_criterion(network4_output, target_1)
    network5_output = network5(network2_output)
    loss2 = loss_criterion(network5_output, target_2)
    network6_output = network6(network3_output)
    loss3 = loss_criterion(network6_output, target_3)
    network7_output = network7(network3_output)
    loss4 = loss_criterion(network7_output, target_4)
    total_loss = loss1 + loss2 + loss3 + loss4
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

I would like to know if the above implementation is correct. From what I understand, gradients are accumulated throughout the backward pass until optimizer.zero_grad() is called, so a single optimizer can be used to adjust all of the network weights.
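That intuition can be checked in isolation. Here is a minimal sketch (the names are illustrative, not from your post) showing that backward() accumulates gradients until zero_grad() clears them, which is why one optimizer over all the parameters is enough:

```python
import torch

# A single scalar parameter and an optimizer over it.
w = torch.ones(1, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

(2 * w).backward()   # d/dw (2w) = 2
(3 * w).backward()   # gradients accumulate: 2 + 3 = 5
assert w.grad.item() == 5.0

opt.zero_grad()      # clear accumulated gradients before the next iteration
# recent PyTorch versions set .grad to None by default; older ones zero it
assert w.grad is None or float(w.grad.sum()) == 0.0
```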

I would appreciate any corrections to my intuition, as well as suggestions for a better way of implementing this.

Be aware that loss1.backward() will backpropagate through network4, network2 and network1, which is probably what you want.

If one of the losses seems to be having too much influence, then you could try multiplying it by a scalar < 1.
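To illustrate: scaling a loss before calling backward() scales every gradient it produces by the same factor (a toy sketch, not your actual networks):

```python
import torch

x = torch.ones(3, requires_grad=True)
loss = (x ** 2).sum()        # unweighted loss; d/dx x^2 = 2x = 2 at x = 1
(0.1 * loss).backward()      # down-weight the loss by 0.1 before backward
# the gradients are scaled by the same 0.1 factor: 0.1 * 2 = 0.2
assert torch.allclose(x.grad, torch.full((3,), 0.2))
```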

All in all though, apart from the occasional misspelling of netwrok for network, it looks good to me.

Thanks for the reply, I will make sure to use network instead of netwrok everywhere :slight_smile: .

@jpeg729 Since I am sharing part of the network across multiple configurations, I am getting an error telling me to set retain_graph=True. But doing so runs me out of memory. Is there a better way of describing the graph and sharing parts of the network so that this doesn’t happen? Please let me know.

@jpeg729 No worries, I figured it out. I have a GAN as part of my network and I forgot to detach.
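For anyone hitting the same retain_graph error: detaching the generator output before feeding it to the discriminator cuts the graph at that point, so the discriminator loss no longer tries to backpropagate through the generator. A minimal sketch with stand-in nn.Linear modules (the actual GAN architecture isn't shown in this thread):

```python
import torch
import torch.nn as nn

# Stand-in modules; the real generator/discriminator are assumed.
generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)

noise = torch.randn(2, 4)
fake = generator(noise)

# Detach the generator output: the discriminator loss then stops at
# `fake` and does not backpropagate into the generator's graph, so
# retain_graph=True is not needed for a later generator update.
d_loss = discriminator(fake.detach()).mean()
d_loss.backward()

# only the discriminator received gradients
assert all(p.grad is None for p in generator.parameters())
assert all(p.grad is not None for p in discriminator.parameters())
```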