I have defined 7 separate networks in 7 separate classes, each of which inherits from nn.Module. I am computing losses at 4 separate places, as follows:
params = list(network1.parameters()) + list(network2.parameters()) + list(network3.parameters()) + \
         list(network4.parameters()) + list(network5.parameters()) + list(network6.parameters()) + \
         list(network7.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
for epoch in range(max_epoch):  # avoid `iter`, which shadows the Python built-in
    optimizer.zero_grad()
    network1_output = network1(input)
    network2_output = network2(network1_output)
    network3_output = network3(network1_output)
    network4_output = network4(network2_output)
    loss1 = loss_criterion(network4_output, target_1)
    # retain_graph=True: the graph through network1/network2/network3 is shared
    # with the later losses, so it must not be freed yet
    loss1.backward(retain_graph=True)
    network5_output = network5(network2_output)
    loss2 = loss_criterion(network5_output, target_2)
    loss2.backward(retain_graph=True)
    network6_output = network6(network3_output)
    loss3 = loss_criterion(network6_output, target_3)
    loss3.backward(retain_graph=True)
    network7_output = network7(network3_output)
    loss4 = loss_criterion(network7_output, target_4)
    loss4.backward()  # last backward pass; the graph can be freed now
    optimizer.step()
I would like to know whether the above implementation is correct. From what I have understood, gradients are accumulated in each parameter's .grad buffer across backward passes until optimizer.zero_grad() is called again, so a single optimizer can be used to update the weights of all the networks.
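To sanity-check this intuition, here is a tiny self-contained toy example (the values are purely illustrative) showing that successive backward() calls accumulate into the same .grad buffer:

```python
import torch

# Toy check: successive backward() calls accumulate gradients into the
# same .grad buffer until zero_grad() clears it.
w = torch.tensor([2.0], requires_grad=True)

loss_a = (3 * w).sum()  # d(loss_a)/dw = 3
loss_a.backward()
loss_b = (5 * w).sum()  # d(loss_b)/dw = 5
loss_b.backward()

print(w.grad)  # the two gradients add up: 3 + 5 = 8
```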
I would appreciate any corrections to my intuition, as well as any better way of implementing the same.
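For reference, a variant I am also considering sums the losses and calls backward() only once. Sketched here with small stand-in modules (illustrative shapes and names, not my real networks):

```python
import torch
import torch.nn as nn

# Stand-in modules for the shared-trunk / multi-head pattern above.
net_shared = nn.Linear(4, 4)
net_head1 = nn.Linear(4, 2)
net_head2 = nn.Linear(4, 2)

params = list(net_shared.parameters()) + list(net_head1.parameters()) + \
         list(net_head2.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
criterion = nn.MSELoss()

x = torch.randn(8, 4)
target_a, target_b = torch.randn(8, 2), torch.randn(8, 2)

optimizer.zero_grad()
shared = net_shared(x)
# Summing the losses builds one graph, so a single backward() suffices
# and no retain_graph is needed.
total_loss = criterion(net_head1(shared), target_a) + \
             criterion(net_head2(shared), target_b)
total_loss.backward()
optimizer.step()
```

Since gradients add up either way, this should produce the same accumulated gradients as calling backward() on each loss separately, while avoiding the shared-graph bookkeeping.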