Backpropagating multiple parallel losses in Pytorch

I am building a cascade of neural networks, and I would like to backpropagate the main loss back through all of the DNNs and also backpropagate an auxiliary loss into each individual DNN.

[Image: diagram of the cascaded DNNs]

I am trying to figure out the best practice for building such a model and how to make sure that my losses are computed properly. Do I build a single torch.nn.Module with a single optimizer, or do I have to create separate modules and optimizers for each network? Also, I am likely to have more than three cascaded DNNs.

Approach a)

import torch
from torch import nn, optim

class MasterNetwork(nn.Module):
    def __init__(self):
        super(MasterNetwork, self).__init__()
        # Placeholders -- each should be a callable sub-network (e.g. an nn.Sequential).
        self.dnn1 = nn.ModuleList()
        self.dnn2 = nn.ModuleList()
        self.dnn3 = nn.ModuleList()

    def forward(self, x, z1, z2):
        out1 = self.dnn1(x)
        out2 = self.dnn2(out1 + z1)
        out3 = self.dnn3(out2 + z2)
        return [out1, out2, out3]

def LossFunction(output):
    # do stuff
    return loss  # loss is a scalar value

def ac_loss_1_fn(output):
    # do stuff
    return loss  # loss is a scalar value

def ac_loss_2_fn(output):
    # do stuff
    return loss  # loss is a scalar value

def ac_loss_3_fn(output):
    # do stuff
    return loss  # loss is a scalar value

model = MasterNetwork()
optimizer = optim.Adam(model.parameters())

input = torch.tensor(...)  # placeholder; the real input data goes here
z1 = torch.tensor(...)
z2 = torch.tensor(...)

outputs = model(input, z1, z2)

main_loss = LossFunction(outputs[2])
ac1_loss = ac_loss_1_fn(outputs[0])
ac2_loss = ac_loss_2_fn(outputs[1])
ac3_loss = ac_loss_3_fn(outputs[2])

optimizer.zero_grad()

'''
This is where I am uncertain about how to backpropagate the AC losses for each DNN
in addition to the main loss.
'''

optimizer.step()

Approach b)
This would involve creating an nn.Module class and an optimizer for each DNN and then forwarding the loss to the next DNN.
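For reference, here is one possible reading of approach b) as a rough sketch, with one module and one optimizer per stage. The architectures, dummy data, and loss terms below are hypothetical placeholders, not part of the original post:

import torch
from torch import nn, optim

# Hypothetical stand-ins for the three sub-networks; replace with the real architectures.
dnn1 = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
dnn2 = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
dnn3 = nn.Sequential(nn.Linear(10, 10), nn.ReLU())

# One optimizer per sub-network.
opt1 = optim.Adam(dnn1.parameters())
opt2 = optim.Adam(dnn2.parameters())
opt3 = optim.Adam(dnn3.parameters())

x = torch.randn(4, 10)   # dummy batch
z1 = torch.randn(4, 10)
z2 = torch.randn(4, 10)

out1 = dnn1(x)
out2 = dnn2(out1 + z1)
out3 = dnn3(out2 + z2)

# Hypothetical loss terms standing in for LossFunction / ac_loss_*_fn.
main_loss = out3.pow(2).mean()
ac1_loss = out1.pow(2).mean()
ac2_loss = out2.pow(2).mean()
ac3_loss = out3.pow(2).mean()

opt1.zero_grad()
opt2.zero_grad()
opt3.zero_grad()
(main_loss + ac1_loss + ac2_loss + ac3_loss).backward()
opt1.step()  # each optimizer only updates its own sub-network's parameters
opt2.step()
opt3.step()

Every additional stage then means another optimizer to construct, zero, and step, which is the tedium mentioned in the next paragraph.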

I would prefer a solution for approach a), since it is less tedious and I don't have to deal with tuning multiple optimizers. However, I am not sure if this is possible. There was a similar question about backpropagating multiple losses; however, I was not able to understand how combining the losses would work for the distinct components.

I have also posted this question on stackoverflow because the website was down for me.


You can minimize a weighted sum of your losses.

Can you elaborate further, please? My concern with summing the losses is that the sum would backpropagate through all components rather than only the specific components each loss is meant to optimize.

import torch

x = torch.ones(1, 10)
x.requires_grad = True

y = 2 * torch.ones(1, 10)
y.requires_grad = True

l1 = (y**2).mean()
l2 = (x**2).mean()
l = l1 + l2      # combined loss
l.backward()     # a single backward pass through the sum
print(x.grad)    # only l2 contributes, since l1 does not depend on x
print(y.grad)    # only l1 contributes, since l2 does not depend on y

x.grad = tensor([[0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])
y.grad = tensor([[0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000]])

It works.
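Applied back to approach a), the uncertain step would then reduce to summing the (optionally weighted) losses and calling backward() once: as in the toy example above, each loss term only produces gradients for the parameters it actually reaches through the graph, so ac1_loss only affects dnn1, ac2_loss affects dnn1 and dnn2 (through out1), and ac3_loss and main_loss affect all three. A minimal sketch, reusing the names from approach a); the weights w_main, w1, w2, w3 are hypothetical hyperparameters:

w_main, w1, w2, w3 = 1.0, 0.3, 0.3, 0.3   # hypothetical loss weights

optimizer.zero_grad()

total_loss = (w_main * main_loss
              + w1 * ac1_loss
              + w2 * ac2_loss
              + w3 * ac3_loss)
total_loss.backward()   # single backward pass populates .grad for all sub-networks

optimizer.step()        # one optimizer updates every parameter in MasterNetwork

If an auxiliary loss should be restricted to a single stage only, one option is to compute it on a separate forward pass from a detached activation (e.g. dnn2(out1.detach() + z1)), though that is a design choice beyond what the thread above settles.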