Incorporating explanations into the loss function

I am working on a binary classification problem over 3D volumes, with mask information available. Volumes labeled 0 contain the class of interest, while those labeled 1 are 'normal'. Accordingly, masks of class-0 volumes contain values 0 and 1, while masks of class-1 volumes contain only 0s (i.e. no information).
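
For concreteness, a dummy batch under this convention might look as follows (the shapes here are placeholders for illustration, not my real ones):

import torch

# Hypothetical shapes: batch of 2 volumes, 32^3 voxels each.
volumes = torch.randn(2, 32, 32, 32)          # raw 3D volumes (channel dim added at the model call)
labels = torch.tensor([0, 1])                 # 0 = class of interest, 1 = 'normal'
masks = torch.stack([
    (torch.rand(32, 32, 32) > 0.9).float(),   # class-0 mask: voxels of interest are 1
    torch.zeros(32, 32, 32),                  # class-1 mask: all zeros, no info
])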

I want to incorporate the GuidedGradCam explanations into the optimization loss, so that the overall loss is an aggregation of the classification loss (BCE) and a loss measuring the similarity between the mask and the attributions of each class-0 sample (again a BCE loss).
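
Written out, the per-batch objective I am after is the following (the 1/2 weight on the XAI term is my own choice):

# L_total = BCE(logits, y) + 0.5 * (1 / |C0|) * sum_{i in C0} BCE(attr_i, mask_i)
#
# where C0 = {i : y_i == 0} is the set of class-0 samples in the batch and
# attr_i are the GuidedGradCam attributions of sample i.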

Below is my code for training one epoch; it is called from a training method which I do not include here.

I would like your assistance in checking whether the total-loss gradients and weight updates are computed correctly; I do see signs of learning so far. (A quick gradient-flow sanity check is sketched after the code.)

Any comment/advice is welcome.

import torch
import warnings
import gc
from torch.cuda.amp import autocast, GradScaler
from captum.attr import GuidedGradCam

warnings.filterwarnings("ignore", category=UserWarning)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

def train_epoch(training_dataloader, model, optimizer, scaler):
    model.to(device)
    model.train()

    batch_clf_losses = []
    batch_xai_losses = []

    target_layer = model.conv4[0]  # layer whose activations GuidedGradCam attributes against
    guided_grad_cam = GuidedGradCam(model, target_layer)  # construct once, not per batch
    bce_criterion = torch.nn.BCEWithLogitsLoss()  # reused for both the clf and the XAI terms

    for steps, batch in enumerate(training_dataloader):

        volumes, masks, labels = batch
        volumes, masks, labels = volumes.to(device), masks.to(device), labels.to(device)

        optimizer.zero_grad()

        with autocast():  # Mixed precision context to handle memory overload

            outputs = model(volumes.unsqueeze(1))  # add the channel dimension expected by the model

            batch_clf_loss = bce_criterion(outputs, labels.float())  # mean-reduced over the batch by default

            batch_xai_loss = 0
            class0_batch_points = 0

            for i in range(labels.shape[0]):

                # Only class-0 volumes contribute to the XAI loss; class-1 volumes have
                # all-zero masks (no class info), so there is nothing for them to match.
                if labels[i] == 0:

                    class0_batch_points += 1

                    grad_cam_attr = guided_grad_cam.attribute(volumes[i].unsqueeze(0).unsqueeze(0), target=0)

                    # Accumulate as a tensor: calling .item() here would detach the XAI
                    # term from the autograd graph and it could never update the weights.
                    batch_xai_loss += bce_criterion(grad_cam_attr, masks[i].unsqueeze(0).unsqueeze(0).float())

                    gc.collect()
                    torch.cuda.empty_cache()

        if class0_batch_points > 0:
            # Normalize by the number of class-0 samples (the per-sample BCE values were
            # summed above); the 1/2 factor weights the XAI term against the clf term.
            batch_xai_loss = 0.5 * (batch_xai_loss / class0_batch_points)

        total_loss = batch_clf_loss + batch_xai_loss

        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        gc.collect()
        torch.cuda.empty_cache()

        batch_clf_losses.append(batch_clf_loss.item())  # store plain floats, not graph-holding tensors
        batch_xai_losses.append(float(batch_xai_loss))  # 0.0 when the batch had no class-0 samples

    epoch_clf_loss = sum(batch_clf_losses) / (steps + 1)  # every batch contributes to the clf loss
    xai_batches = [x for x in batch_xai_losses if x != 0]  # only batches with class-0 samples contribute
    epoch_xai_loss = sum(xai_batches) / max(len(xai_batches), 1)  # guard against an epoch with no class-0 batches

    return epoch_clf_loss, epoch_xai_loss
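
As a sanity check on the gradient question itself: the XAI term can only influence the weight updates if the attributions Captum returns are still attached to the autograd graph. A minimal check on a single class-0 sample (sample_volume and sample_mask below are placeholder names) could be:

from captum.attr import GuidedGradCam

# Run once, outside the AMP context, with the model in train mode.
ggc = GuidedGradCam(model, model.conv4[0])
attr = ggc.attribute(sample_volume.unsqueeze(0).unsqueeze(0), target=0)
print(attr.requires_grad, attr.grad_fn)  # False / None => attributions are detached

xai = torch.nn.BCEWithLogitsLoss()(attr, sample_mask.unsqueeze(0).unsqueeze(0).float())
print(xai.requires_grad)  # must be True for the XAI term to reach the weights

If the attributions come back detached, the total loss still backpropagates through the classification term alone, which could explain seeing some signs of learning even while the XAI term stays inert.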