I am building a cascade of neural networks, and I would like to backpropagate the main loss to each of the DNNs and also backpropagate an auxiliary loss to each individual DNN.
I am trying to figure out what the best practice is when building such a model and how to make sure that my losses are computed properly. Do I build a single torch.nn.Module with a single optimizer (approach a), or do I have to create separate modules and optimizers for each network (approach b)? I am also likely to have more than three cascaded DNNs. Below is a skeleton of approach a):
```
import torch
from torch import nn, optim

class MasterNetwork(nn.Module):
    def __init__(self):
        super(MasterNetwork, self).__init__()
        # placeholders for the actual sub-networks
        self.dnn1 = nn.Sequential()
        self.dnn2 = nn.Sequential()
        self.dnn3 = nn.Sequential()

    def forward(self, x, z1, z2):
        out1 = self.dnn1(x)
        out2 = self.dnn2(out1 + z1)
        out3 = self.dnn3(out2 + z2)
        return [out1, out2, out3]

def LossFunction(outputs):
    # do stuff
    return loss  # loss is a scalar value

def ac_loss_1_fn(outputs):
    # do stuff
    return loss  # loss is a scalar value

def ac_loss_2_fn(outputs):
    # do stuff
    return loss  # loss is a scalar value

def ac_loss_3_fn(outputs):
    # do stuff
    return loss  # loss is a scalar value

model = MasterNetwork()
optimizer = optim.Adam(model.parameters())

input = torch.tensor()  # placeholder input
z1 = torch.tensor()     # placeholder
z2 = torch.tensor()     # placeholder

outputs = model(input, z1, z2)

main_loss = LossFunction(outputs)
ac1_loss = ac_loss_1_fn(outputs)
ac2_loss = ac_loss_2_fn(outputs)
ac3_loss = ac_loss_3_fn(outputs)

optimizer.zero_grad()
'''
This is where I am uncertain about how to backpropagate the AC losses
for each DNN in addition to the main loss.
'''
optimizer.step()
```
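My current guess for approach a) is that I could weight and sum all four losses into one scalar and call backward() once, since autograd should only send gradients from each loss to the parameters that loss actually depends on. A minimal sketch of what I mean (w1, w2, w3 are made-up weighting hyperparameters, not something from my actual code):

```
# Tentative idea for approach a): combine the losses into a single scalar
# and backpropagate once; the single optimizer then updates all three DNNs.
w1, w2, w3 = 1.0, 1.0, 1.0  # hypothetical auxiliary-loss weights

optimizer.zero_grad()
total_loss = main_loss + w1 * ac1_loss + w2 * ac2_loss + w3 * ac3_loss
total_loss.backward()  # gradients from every loss are accumulated in .grad
optimizer.step()
```

I am not sure whether this is actually the recommended way to handle per-DNN auxiliary losses.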
Approach b) would involve creating an nn.Module class and an optimizer for each DNN and then forwarding the loss to the next DNN.
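For completeness, my rough mental picture of approach b) is something like the following, where Net1/Net2/Net3 are hypothetical sub-network classes and the loss functions are the ones from the skeleton above; I am not sure whether this backward/step sequence is what "forwarding the loss" is supposed to look like:

```
# Hypothetical sketch of approach b): one nn.Module and one optimizer per DNN.
dnn1, dnn2, dnn3 = Net1(), Net2(), Net3()  # placeholder sub-network classes
opt1 = optim.Adam(dnn1.parameters())
opt2 = optim.Adam(dnn2.parameters())
opt3 = optim.Adam(dnn3.parameters())

opt1.zero_grad(); opt2.zero_grad(); opt3.zero_grad()

out1 = dnn1(x)
out2 = dnn2(out1 + z1)
out3 = dnn3(out2 + z2)

total_loss = (LossFunction([out1, out2, out3])
              + ac_loss_1_fn([out1, out2, out3])
              + ac_loss_2_fn([out1, out2, out3])
              + ac_loss_3_fn([out1, out2, out3]))
total_loss.backward()

opt1.step(); opt2.step(); opt3.step()  # each optimizer updates only its own DNN
```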
I would prefer a solution for approach a), since it is less tedious and I don't have to deal with tuning multiple optimizers. However, I am not sure if this is possible. There was a similar question about backpropagating multiple losses; however, I was not able to understand how combining the losses would work for the distinct components.
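To make my confusion concrete, here is a tiny self-contained toy cascade (the shapes and losses are made up) that I would use to check which parameters each loss actually reaches; my expectation is that an auxiliary loss on an early stage has no gradient path back to the later stages:

```
import torch
from torch import nn

# Toy cascade with made-up shapes, just to inspect gradient flow.
dnn1, dnn2, dnn3 = nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4)

x, z1, z2 = torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4)

out1 = dnn1(x)
out2 = dnn2(out1 + z1)
out3 = dnn3(out2 + z2)

ac1_loss = out1.pow(2).mean()    # depends only on dnn1
main_loss = out3.pow(2).mean()   # depends on dnn1, dnn2 and dnn3

# ac1_loss has no path to dnn3, so its gradients there come back as None:
g = torch.autograd.grad(ac1_loss, list(dnn3.parameters()),
                        retain_graph=True, allow_unused=True)
print([gi is None for gi in g])  # [True, True]

# main_loss reaches the earliest stage as well:
g = torch.autograd.grad(main_loss, list(dnn1.parameters()), retain_graph=True)
print([gi.shape for gi in g])    # gradients for dnn1.weight and dnn1.bias
```

Is this the right way to think about how a combined loss is distributed over the distinct components?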
I have also posted this question on Stack Overflow because the website was down for me.