Updating MaskRCNN loss function for class imbalance?

I’m creating an instance segmentation model with MaskRCNN. There are four classes: A, B, C, and D. The problem is that there are only about 100 samples each of A and B but about 1000 samples each of C and D, so the data has a significant (roughly 10:1) class imbalance.

Given the nature of the data (it’s medical), I can’t easily gather more. Instead, I’d like to modify the loss function to address the class imbalance, but I’m not sure how to do this.
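One common way to turn class counts into per-class weights is inverse-frequency ("balanced") weighting, w_c = total / (num_classes × n_c). This is a swapped-in technique, not something taken from the code below; the counts are the approximate ones stated above.

```python
# Approximate per-class sample counts from the question.
counts = {"A": 100, "B": 100, "C": 1000, "D": 1000}

total = sum(counts.values())  # 2200 samples in total
num_classes = len(counts)     # 4 classes

# Inverse-frequency weighting: rarer classes get proportionally larger weights.
weights = {c: total / (num_classes * n) for c, n in counts.items()}
print(weights)  # {'A': 5.5, 'B': 5.5, 'C': 0.55, 'D': 0.55}
```

With these weights every class carries the same total weight (100 × 5.5 = 1000 × 0.55 = 550). torchvision reserves label 0 for background, so, assuming A to D are labels 1 to 4, a weight tensor indexed by label would need a placeholder entry at index 0.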

This is the mask loss function that my MaskRCNN model uses (I’m just using the pretrained model provided by PyTorch’s torchvision):

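```python
from typing import List

import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision.models.detection.roi_heads import project_masks_on_boxes


def maskrcnn_loss(
    mask_logits: Tensor,
    proposals: List[Tensor],
    gt_masks: List[Tensor],
    gt_labels: List[Tensor],
    mask_matched_idxs: List[Tensor],
) -> Tensor:
    """
    Arguments:
        mask_logits (Tensor): predicted mask logits, shape [N, num_classes, M, M]
        proposals (list[Tensor]): positive proposal boxes for each image
        gt_masks (list[Tensor]): ground-truth masks for each image
        gt_labels (list[Tensor]): ground-truth class labels for each image
        mask_matched_idxs (list[Tensor]): index of the ground-truth object matched to each proposal

    Return:
        mask_loss (Tensor): scalar tensor containing the loss
    """

    # Side length M of the square mask predicted for each proposal.
    discretization_size = mask_logits.shape[-1]
    # Class label of the ground-truth object matched to each proposal.
    labels = [l[idxs] for l, idxs in zip(gt_labels, mask_matched_idxs)]
    # Crop each matched ground-truth mask to its proposal box and resample it to M x M.
    mask_targets = [
        project_masks_on_boxes(m, p, i, discretization_size)
        for m, p, i in zip(gt_masks, proposals, mask_matched_idxs)
    ]

    # Concatenate the per-image lists into single batches of N proposals.
    labels = torch.cat(labels, dim=0)
    mask_targets = torch.cat(mask_targets, dim=0)

    # torch.mean (in binary_cross_entropy_with_logits) doesn't
    # accept empty tensors, so handle it separately
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    # For each proposal keep only the mask predicted for its own class, then take
    # the mean BCE over all pixels of all proposals (every class weighted equally).
    mask_loss = F.binary_cross_entropy_with_logits(
        mask_logits[torch.arange(labels.shape[0], device=labels.device), labels], mask_targets
    )
    return mask_loss
```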
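The part of this function that is easiest to misread is the indexing in the final call. The mask head predicts one M×M mask per class for every positive proposal, so `mask_logits` has shape `[N, num_classes, M, M]`, and `mask_logits[torch.arange(N), labels]` keeps, for proposal i, only the mask for class `labels[i]`. A toy illustration with made-up sizes:

```python
import torch

# Toy sizes (made up): 3 proposals, 5 class channels (0 = background), 2x2 masks.
num_rois, num_classes, m = 3, 5, 2
mask_logits = torch.arange(num_rois * num_classes * m * m, dtype=torch.float32).reshape(
    num_rois, num_classes, m, m
)
labels = torch.tensor([1, 3, 4])  # class of the ground truth matched to each proposal

# Pairs row i with labels[i]: picks mask_logits[0, 1], mask_logits[1, 3], mask_logits[2, 4].
selected = mask_logits[torch.arange(num_rois), labels]
print(selected.shape)  # torch.Size([3, 2, 2])
```

Because `labels` already records the class of each proposal, it is also the natural place to look up a per-class weight.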

My question is: how can I modify this function to account for the class imbalance, for example by adding a weight for each class? (I didn’t write this code, and it doesn’t make much sense to me.) Any advice would be appreciated!
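A minimal sketch of per-class weighting, under stated assumptions: keep `maskrcnn_loss` as-is up to the empty-tensor check and replace only the final `F.binary_cross_entropy_with_logits` call. The helper `weighted_mask_bce` and its `class_weights` argument are hypothetical (not part of torchvision); `class_weights` is assumed to be a 1-D tensor with one entry per class index, e.g. inverse-frequency weights derived from the class counts. The BCE is computed unreduced (`reduction="none"`), averaged per proposal, and then combined as a weighted mean using `class_weights[labels]`.

```python
import torch
import torch.nn.functional as F


def weighted_mask_bce(mask_logits, labels, mask_targets, class_weights):
    """Hypothetical per-class weighted replacement for the last step of maskrcnn_loss.

    mask_logits:   [N, num_classes, M, M] mask logits for the N positive proposals
    labels:        [N] class index of the ground truth matched to each proposal
    mask_targets:  [N, M, M] ground-truth masks projected onto each proposal
    class_weights: [num_classes] assumed weight per class index (index 0 = background,
                   unused here because only positive proposals reach the mask loss)
    """
    # Same empty-batch guard as the original function.
    if mask_targets.numel() == 0:
        return mask_logits.sum() * 0

    # Keep each proposal's mask for its own class, exactly as the original code does.
    selected = mask_logits[torch.arange(labels.shape[0], device=labels.device), labels]

    # Unreduced BCE: one loss value per pixel, shape [N, M, M].
    per_pixel = F.binary_cross_entropy_with_logits(selected, mask_targets, reduction="none")

    # Mean over pixels gives one loss per proposal, shape [N].
    per_roi = per_pixel.mean(dim=(1, 2))

    # Look up the weight of each proposal's class.
    w = class_weights.to(device=per_roi.device, dtype=per_roi.dtype)[labels]

    # Weighted mean over proposals; with all weights equal this equals the original loss.
    return (w * per_roi).sum() / w.sum()
```

Dividing by `w.sum()` rather than by the number of proposals keeps the loss on roughly the same scale as the original. To use it with the pretrained model, one commonly used but version-dependent approach is to write a copy of `maskrcnn_loss` with the same five-argument signature that calls this helper (binding the weights in a closure), and assign it to `torchvision.models.detection.roi_heads.maskrcnn_loss`, which `RoIHeads.forward` looks up at call time. Note that this only reweights `loss_mask`; the class assigned to each detection comes from the box classifier (`loss_classifier`), whose cross-entropy could be weighted in the same spirit via the `weight` argument of `F.cross_entropy`.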