How to train a multi-branch output network?

Hello, everyone!

I am building a multi-branch network; here is the setup I am working with.

Each data sample in the batch has its own class, so each sample should be routed to the decoder that corresponds to its class.
How can I train this kind of network?

Here is a training step in my current code.

import torch

# All encoder and decoder parameters are optimized together by one optimizer.
params = list(Encoder.parameters()) + list(Decoder1.parameters()) \
    + list(Decoder2.parameters()) + list(Decoder3.parameters())
optim = torch.optim.Adadelta(params, lr=learning_rate)

encoder_output = Encoder(input)
optim.zero_grad()
loss_1, loss_2, loss_3 = 0, 0, 0

# Split the batch: batch_ind_1/2/3 hold the indices of the samples of
# class 1/2/3, so each subset goes through its own decoder.
if encoder_output[batch_ind_1].shape[0] != 0:
    output1 = Decoder1(encoder_output[batch_ind_1])
    loss_1 = loss_fn(output1, ground_truth[batch_ind_1])

if encoder_output[batch_ind_2].shape[0] != 0:
    output2 = Decoder2(encoder_output[batch_ind_2])
    loss_2 = loss_fn(output2, ground_truth[batch_ind_2])

if encoder_output[batch_ind_3].shape[0] != 0:
    output3 = Decoder3(encoder_output[batch_ind_3])
    loss_3 = loss_fn(output3, ground_truth[batch_ind_3])

# Sum the branch losses and update everything in one step.
loss = loss_1 + loss_2 + loss_3
loss.backward()
optim.step()
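
For context, the batch_ind_# tensors are the per-class sample indices. Assuming each batch comes with a labels tensor (a hypothetical name here) holding the class id of every sample, they could be built roughly like this:

# Hypothetical: labels has shape (batch_size,) and contains the class id (1, 2, or 3) of each sample.
batch_ind_1 = (labels == 1).nonzero(as_tuple=True)[0]  # indices of class-1 samples
batch_ind_2 = (labels == 2).nonzero(as_tuple=True)[0]  # indices of class-2 samples
batch_ind_3 = (labels == 3).nonzero(as_tuple=True)[0]  # indices of class-3 samples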
                

What I expect is that each loss_# is backpropagated only through its own Decoder#, while the Encoder is trained by the total sum of the losses, i.e. the backpropagation should happen like
loss_# → Decoder#,
total_loss → Encoder.

Does the code above run as I expect?


Yes, that should work as expected.
To verify it, you could zero out some losses and check the gradients in each module separately.
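
For example, a minimal sketch of such a check, reusing the names from the code above and dropping loss_2 from the sum:

optim.zero_grad()
loss = loss_1 + loss_3  # leave out the class-2 branch on purpose
loss.backward()

# Decoder2 should have received no gradient from this backward pass ...
print(all(p.grad is None or p.grad.abs().sum() == 0 for p in Decoder2.parameters()))
# ... while Decoder1 and the Encoder should have non-zero gradients
# (assuming class-1 samples were present in the batch).
print(any(p.grad is not None and p.grad.abs().sum() > 0 for p in Decoder1.parameters()))
print(any(p.grad is not None and p.grad.abs().sum() > 0 for p in Encoder.parameters()))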
