I am working on a binary classification problem with 3D volumes, with mask info available. Volumes labeled by 0 contain the class of interest, while the ones labeled by 1 are ‘normal’. Likewise, masks of class 0 have values 0 and 1, while masks of class 1 have values of 0 only (i.e. no info).
I want to incorporate the GuidedGradCam explanations into the optimization loss in such a way that the overall loss is an aggregation of the classification loss (BCE) and a loss that measures the similarity between the mask and the attributions of class 0 (again a BCE loss).
Below is an instance of my code for epoch training. In turn this is fed into a training method which I do not include here.
I would like your assistance on whether the total loss gradients and weight updates are calculated correctly, as I do see signs of learning so far.
Any comment/advice is welcome.
import torch
import warnings
import gc
from torch.cuda.amp import autocast, GradScaler
from captum.attr import GuidedGradCam
warnings.filterwarnings("ignore", category=UserWarning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train_epoch(training_dataloader, model, optimizer, scaler):
batch_clf_losses = []
batch_xai_losses = []
target_layer = model.conv4[0] #for GuidedGradCam
for steps, batch in enumerate(training_dataloader):
volumes, masks, labels = batch
volumes, masks, labels = volumes.to(device), masks.to(device), labels.to(device)
with autocast(): # Mixed precision context to handle memory overload
outputs = model(volumes.unsqueeze(1))
bce_loss = torch.nn.BCEWithLogitsLoss(pos_weight=None)
batch_clf_loss = bce_loss(outputs, labels.float()) #this is batch normalized by default
grad_cam = GuidedGradCam(model, conv_layer)
batch_xai_loss = 0
class0_batch_points = 0
for i in range(labels.shape[0]):
if labels[i]==0: #volumes with label 0 contribute to the xai loss calculation
#volumes with label 1 have masks with 0 values only (no class info), thus nothing to contribute
class0_batch_points += 1
grad_cam_attr = grad_cam.attribute(volumes[i].unsqueeze(0).unsqueeze(0), 0)
bce_loss_xai = torch.nn.BCEWithLogitsLoss()
batch_xai_loss += bce_loss_xai(grad_cam_attr, masks[i].unsqueeze(0).unsqueeze(0)float()).item()
if vol_class0_batch_points>0:
batch_xai_loss = (1/2) * (batch_xai_loss/class0_batch_points) #normalize here, because bce xai loss is computed per class 0 point above
total_loss = batch_clf_loss + batch_xai_loss
epoch_clf_loss = sum(batch_clf_losses) / (steps+1) #all batches contribute to classification loss
epoch_xai_loss = sum(batch_xai_losses) / len([x for x in batch_xai_losses if x!=0]) #consider only batches that contribute to xai loss
return epoch_clf_loss, epoch_xai_loss