I am implementing a Vision Transformer architecture that consists of a Shared Encoder and three Decoders for three different tasks, with white balancing as the primary task. Each of these four components is implemented as its own nn.Module.
During training, all four components are trained together. The Primary Decoder produces three separate losses (L1, L_ang, L_surf), while each of the two Auxiliary Decoders has one loss (L_A and L_E). The loss of the Encoder should be the sum of these five losses (plus L_cont, which I have not implemented yet).
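Writing the decoder parameters as θ_P (primary), θ_A (achromatic), θ_E (edge) and the encoder parameters as θ_enc (notation introduced here for illustration, with the `loss_weights` factors omitted since each one only scales its own term), the gradient of a single summed loss splits as follows, because a loss term contributes nothing to parameters it does not depend on:

$$
\begin{aligned}
\mathcal{L}_{\text{total}} &= L_1 + L_{\text{ang}} + L_{\text{surf}} + L_A + L_E \\
\nabla_{\theta_P} \mathcal{L}_{\text{total}} &= \nabla_{\theta_P}\left(L_1 + L_{\text{ang}} + L_{\text{surf}}\right) \\
\nabla_{\theta_A} \mathcal{L}_{\text{total}} &= \nabla_{\theta_A} L_A \\
\nabla_{\theta_E} \mathcal{L}_{\text{total}} &= \nabla_{\theta_E} L_E \\
\nabla_{\theta_{\text{enc}}} \mathcal{L}_{\text{total}} &= \nabla_{\theta_{\text{enc}}}\left(L_1 + L_{\text{ang}} + L_{\text{surf}} + L_A + L_E\right)
\end{aligned}
$$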
My problem is how to backpropagate these losses correctly and tell PyTorch which module each loss should be backpropagated to. Calling .backward() several times results in an error. I have seen suggestions to simply sum all the losses and call .backward() once, letting autograd handle the rest. However, I do not fully understand this approach, and I am unsure whether it assigns each loss to the intended modules.
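A minimal, self-contained sketch of what summing does, using hypothetical toy `nn.Linear` stand-ins rather than the actual ViT modules: it compares one `backward()` on the summed loss with separate `backward()` calls, and checks that a loss only produces gradients for modules on its own path. The separate-call variant needs `retain_graph=True` on the first call because both losses share the encoder's part of the graph; a missing `retain_graph` is a likely (not confirmed) cause of the error with multiple `.backward()` calls.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for the shared encoder and two of the decoders.
encoder = nn.Linear(4, 8)
primary_decoder = nn.Linear(8, 3)
edge_decoder = nn.Linear(8, 1)
all_modules = [encoder, primary_decoder, edge_decoder]
x = torch.randn(16, 4)


def compute_losses():
    features = encoder(x)  # shared by both decoders
    loss_primary = primary_decoder(features).pow(2).mean()
    loss_edge = edge_decoder(features).abs().mean()
    return loss_primary, loss_edge


def snapshot_grads():
    # Copy every parameter gradient, in a fixed order, for comparison.
    return [p.grad.clone() for m in all_modules for p in m.parameters()]


def clear_grads():
    for m in all_modules:
        m.zero_grad(set_to_none=True)


# Option 1: sum the losses and call backward() once.
clear_grads()
loss_primary, loss_edge = compute_losses()
(loss_primary + loss_edge).backward()
summed = snapshot_grads()

# Option 2: one backward() per loss. The first call keeps the shared graph
# alive; without retain_graph=True the second call raises a RuntimeError
# about backpropagating through the graph a second time.
clear_grads()
loss_primary, loss_edge = compute_losses()
loss_primary.backward(retain_graph=True)
loss_edge.backward()  # gradients accumulate into the existing .grad tensors
separate = snapshot_grads()

print(all(torch.allclose(a, b) for a, b in zip(summed, separate)))  # True

# The edge loss does not depend on the primary decoder, so autograd
# reports no gradient for the primary decoder's weight.
_, loss_edge = compute_losses()
print(torch.autograd.grad(loss_edge, primary_decoder.weight, allow_unused=True))  # (None,)
```

Both options leave the same gradients on every parameter; the summed version traverses the shared encoder graph only once and needs no `retain_graph`.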
Here is the code I am currently using for the backward pass within the Training function:
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        pred, achromatic_map, pred_hed = get_decoders_outputs(modules, *modules["encoder"](input))
        total_loss = get_total_loss(input, pred, gt, achromatic_map, pred_hed, gt_hed, loss_weights)

    # Loss
    set_zero_grad(optimizers)
    grad_scaler.scale(total_loss).backward()
    update_modules(modules, optimizers, grad_scaler)
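The helpers `set_zero_grad` and `update_modules` are not shown above. A minimal sketch of what they might look like, assuming (hypothetically) that `optimizers` is a dict with one optimizer per module: the single scaled backward fills `.grad` on every module, and with several optimizers `GradScaler` expects `step()` on each of them followed by one `update()` per iteration. The scaler is created with `enabled=False` only so the example runs on CPU.

```python
import torch
import torch.nn as nn


def set_zero_grad(optimizers):
    # Clear last iteration's gradients on every optimizer.
    for opt in optimizers.values():
        opt.zero_grad(set_to_none=True)


def update_modules(modules, optimizers, grad_scaler):
    # `modules` is unused in this sketch; kept to match the call signature.
    # Step every optimizer through the scaler, then update the scale once.
    for opt in optimizers.values():
        grad_scaler.step(opt)
    grad_scaler.update()


# Usage with hypothetical toy modules.
torch.manual_seed(0)
modules = {"encoder": nn.Linear(4, 8), "primary": nn.Linear(8, 3)}
optimizers = {name: torch.optim.SGD(m.parameters(), lr=0.1) for name, m in modules.items()}
grad_scaler = torch.cuda.amp.GradScaler(enabled=False)  # enable on CUDA

x = torch.randn(16, 4)
before = modules["encoder"].weight.detach().clone()

set_zero_grad(optimizers)
total_loss = modules["primary"](modules["encoder"](x)).pow(2).mean()
grad_scaler.scale(total_loss).backward()
update_modules(modules, optimizers, grad_scaler)

print(torch.equal(before, modules["encoder"].weight))  # False: encoder updated
```

A single optimizer over the parameters of all four modules would work equally well; separate optimizers are mainly useful when the modules need different learning rates or schedules.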
Additionally, here is the code for computing the losses:
    def get_primary_loss(input, pred, gt, loss_weights):
        loss_score = 0
        loss_score += L1_loss(pred, gt, loss_weights[0])
        loss_score += angular_loss(input, pred, gt, loss_weights[1])
        loss_score += surf_loss(pred, gt, loss_weights[2])
        return loss_score

    def get_total_loss(input, pred, gt, achromatic_map, pred_hed, gt_hed, loss_weights):
        primary_loss = get_primary_loss(input, pred, gt, loss_weights[2:5])
        achromatic_loss = achromatic_pixel_detection_loss(gt, achromatic_map, loss_weights[0])
        edge_loss = edge_detection_loss(pred_hed, gt_hed, loss_weights[1])
        total_loss = primary_loss + edge_loss + achromatic_loss
        return total_loss
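Positional slicing such as `loss_weights[2:5]` is easy to get out of sync when a term like L_cont is added later. A minimal sketch of a named-weight alternative; the function name and the loss/weight keys are hypothetical. It returns the differentiable total (to call `.backward()` on once) plus detached per-term values for logging.

```python
import torch


def weighted_total_loss(losses, weights):
    """Sum named loss terms scaled by named weights.

    Returns the differentiable total and a dict of detached per-term
    values (plain floats) that can be logged without holding the graph.
    """
    missing = losses.keys() - weights.keys()
    if missing:
        raise KeyError(f"no weight given for: {sorted(missing)}")
    total = sum(weights[name] * value for name, value in losses.items())
    logged = {name: value.detach().item() for name, value in losses.items()}
    return total, logged


# Usage with dummy scalar losses standing in for the real loss functions.
losses = {
    "l1": torch.tensor(0.5, requires_grad=True),
    "ang": torch.tensor(2.0, requires_grad=True),
    "edge": torch.tensor(1.0, requires_grad=True),
}
weights = {"l1": 1.0, "ang": 0.1, "edge": 0.5}
total, logged = weighted_total_loss(losses, weights)
total.backward()
print(losses["ang"].grad)  # tensor(0.1000)
```

Each term's gradient is simply its weight, which is the same routing behaviour as the original `primary_loss + edge_loss + achromatic_loss`.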