Correctly Backpropagating Losses to Different Models

I am currently implementing a Vision Transformer architecture that consists of a shared encoder and three decoders for three different tasks, the primary one being white balancing. Each of these four components is implemented as its own nn.Module.

During training, all four components are trained together. The primary decoder produces three separate losses (L1, L_ang, L_surf), while the two auxiliary decoders each have one loss (L_A and L_E). The loss for the encoder should be the sum of these five losses (plus L_cont, which I have not implemented yet).
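In other words, the encoder's training signal is simply the sum of all task losses:

    L_enc = L1 + L_ang + L_surf + L_A + L_E   # (+ L_cont once it is implemented)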

My challenge lies in backpropagating these losses correctly and telling PyTorch which module each loss should be backpropagated to. Calling .backward() multiple times results in an error. I have come across suggestions to simply sum all the losses and call .backward() once on the result, letting autograd handle the rest. However, I do not fully understand this approach and am unsure whether it actually assigns each loss to the intended module.

Here is the code I am currently using for the backward pass within the training function:

    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        pred, achromatic_map, pred_hed = get_decoders_outputs(modules, *modules["encoder"](input))
        total_loss = get_total_loss(input, pred, gt, achromatic_map, pred_hed, gt_hed, loss_weights)

    # Backward pass and parameter updates
    update_modules(modules, optimizers, grad_scaler)

Additionally, here is the code for computing the losses:

    def get_primary_loss(input, pred, gt, loss_weights):
        # Weighted sum of the three primary-decoder losses: L1, angular and surface loss
        loss_score = 0
        loss_score += L1_loss(pred, gt, loss_weights[0])
        loss_score += angular_loss(input, pred, gt, loss_weights[1])
        loss_score += surf_loss(pred, gt, loss_weights[2])
        return loss_score

    def get_total_loss(input, pred, gt, achromatic_map, pred_hed, gt_hed, loss_weights):
        # Total loss: primary-decoder losses plus the two auxiliary-decoder losses
        primary_loss = get_primary_loss(input, pred, gt, loss_weights[2:5])
        achromatic_loss = achromatic_pixel_detection_loss(gt, achromatic_map, loss_weights[0])
        edge_loss = edge_detection_loss(pred_hed, gt_hed, loss_weights[1])
        total_loss = primary_loss + edge_loss + achromatic_loss
        return total_loss

Summing the losses and calling backward() once on the result will backpropagate through each of the summands, so this approach works. Addition of loss tensors is just another differentiable operation that autograd tracks: every parameter ends up with the gradient of the total loss, which is the sum of the gradients of exactly those loss terms that depend on that parameter. A decoder therefore only receives gradients from its own loss, while the shared encoder receives contributions from all of them.
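As a minimal, self-contained sketch of this (using toy nn.Linear stand-ins instead of your actual ViT encoder and decoders, and made-up placeholder losses):

    import torch
    import torch.nn as nn

    # Toy stand-ins for the shared encoder and the three decoders
    encoder = nn.Linear(16, 32)
    primary_decoder = nn.Linear(32, 3)      # would produce L1 + L_ang + L_surf
    achromatic_decoder = nn.Linear(32, 1)   # would produce L_A
    edge_decoder = nn.Linear(32, 1)         # would produce L_E

    x = torch.randn(4, 16)
    gt = torch.randn(4, 3)

    features = encoder(x)                   # shared features feed all three decoders
    primary_loss = nn.functional.mse_loss(primary_decoder(features), gt)  # placeholder for the primary losses
    achromatic_loss = achromatic_decoder(features).abs().mean()           # placeholder for L_A
    edge_loss = edge_decoder(features).pow(2).mean()                      # placeholder for L_E

    total_loss = primary_loss + achromatic_loss + edge_loss
    total_loss.backward()                   # one backward call over the shared graph

    # Each decoder only receives gradients from the loss computed on its own output,
    # while the shared encoder accumulates gradient contributions from all three losses.
    print(encoder.weight.grad.norm())            # non-zero: all losses contribute
    print(primary_decoder.weight.grad.norm())    # non-zero: only primary_loss contributes
    print(achromatic_decoder.weight.grad.norm()) # non-zero: only achromatic_loss contributes

Whether you then use one optimizer over all parameters or one optimizer per module does not matter: after the single backward() call the correct gradients are already sitting in each module's .grad fields, so you just call step() on whatever optimizers you have.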
If you call backward() multiple times, I assume you are running into an error saying the computation graph has already been freed, which indicates that the losses share some common layers/parameters (here, the shared encoder).
In this case, the first backward() call frees the intermediate activations that the remaining losses still need for their gradients. You could work around this by passing retain_graph=True to all backward() calls except the last one, but summing the losses is simpler and produces the same gradients.
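For completeness, the retain_graph variant would look like this on the toy example above (after zeroing the gradients and re-running the forward pass, it accumulates exactly the same .grad values as the single backward() call on the sum):

    # Separate backward calls: all but the last need retain_graph=True because
    # the losses share the encoder's part of the computation graph.
    for m in (encoder, primary_decoder, achromatic_decoder, edge_decoder):
        m.zero_grad()

    features = encoder(x)
    primary_loss = nn.functional.mse_loss(primary_decoder(features), gt)
    achromatic_loss = achromatic_decoder(features).abs().mean()
    edge_loss = edge_decoder(features).pow(2).mean()

    primary_loss.backward(retain_graph=True)     # keep the shared activations alive
    achromatic_loss.backward(retain_graph=True)  # still needs the encoder activations
    edge_loss.backward()                         # the last call may free the graph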