Val loss and metrics generally better than Train metrics

metrics.py
from sklearn.metrics import accuracy_score
import torch
from monai.metrics import compute_hausdorff_distance
from statistics import mean
import numpy as np
import warnings

class BinarySegMetrics:
    """Accumulate IoU, Dice, pixel accuracy and HD95 for binary segmentation.

    Each call to ``compute_metrics`` scores one batch and appends the batch
    means to ``running_metrics``; ``get_epoch_metrics`` then averages those
    per-batch values (a mean of batch means, not a per-sample mean).

    Samples where both prediction and target are empty score IoU = Dice = 1
    and HD95 = 0, so a split containing many background-only patches will
    report better averages than one containing few.

    Args:
        epsilon: Small constant used to stabilise the IoU and Dice ratios.
        hd95_empty_penalty: HD95 value recorded when exactly one of
            prediction/target is empty. ``None`` means "use the image
            diagonal computed from the last two dimensions".
    """

    def __init__(self, epsilon=1e-6, hd95_empty_penalty=None):
        self.epsilon = epsilon
        self.hd95_empty_penalty = hd95_empty_penalty
        self.running_metrics = {
            'iou': [],
            'dice': [],
            'accuracy': [],
            'hd95': []
        }
        self.hd95_cases = self._new_case_counts()

    @staticmethod
    def _new_case_counts():
        """Return a fresh counter dict for the four HD95 situations."""
        return {
            'both_empty': 0,  # Both pred and target empty (perfect match)
            'false_positive': 0,  # Target empty, pred not empty
            'false_negative': 0,  # Pred empty, target not empty
            'both_present': 0  # Both have content
        }

    @staticmethod
    def _image_diagonal(mask):
        """Return the diagonal length spanned by the last two dims of ``mask``."""
        return float(np.sqrt(mask.shape[-1] ** 2 + mask.shape[-2] ** 2))

    def _empty_penalty(self, mask):
        """Return the HD95 value to record when only one mask is empty."""
        if self.hd95_empty_penalty is None:
            return self._image_diagonal(mask)
        return self.hd95_empty_penalty

    def _get_empty_mask_flags(self, preds, targets):
        """Return two boolean tensors flagging all-zero masks per batch item."""
        # reshape (not view) so non-contiguous inputs are accepted as well.
        preds_flat = preds.reshape(preds.size(0), -1)
        targets_flat = targets.reshape(targets.size(0), -1)

        pred_empty = (preds_flat.sum(dim=1) == 0)
        target_empty = (targets_flat.sum(dim=1) == 0)

        return pred_empty, target_empty

    def _compute_hd95_with_empty_handling(self, preds, targets):
        """Return the batch-mean HD95, substituting fixed values for empty masks.

        Per sample:
          * both empty        -> 0.0
          * exactly one empty -> ``hd95_empty_penalty`` or the image diagonal
          * both non-empty    -> ``compute_hausdorff_distance`` at the 95th
            percentile; a NaN/inf result or a raised exception falls back to
            the image diagonal.

        ``hd95_cases`` is updated as a side effect. Returns 0.0 for a batch
        with no samples.
        """
        pred_empty, target_empty = self._get_empty_mask_flags(preds, targets)
        # Pull the flags to Python once instead of truth-testing a tensor
        # element on every branch of every iteration.
        pred_flags = pred_empty.tolist()
        target_flags = target_empty.tolist()
        hd95_values = []

        for i, (p_empty, t_empty) in enumerate(zip(pred_flags, target_flags)):
            pred_i = preds[i:i + 1]
            target_i = targets[i:i + 1]

            if p_empty and t_empty:
                # Both empty - perfect segmentation
                hd95_values.append(0.0)
                self.hd95_cases['both_empty'] += 1

            elif t_empty:
                # False positive - target empty but model predicted something
                hd95_values.append(self._empty_penalty(pred_i))
                self.hd95_cases['false_positive'] += 1

            elif p_empty:
                # False negative - model missed the target
                hd95_values.append(self._empty_penalty(pred_i))
                self.hd95_cases['false_negative'] += 1

            else:
                # Both have content - compute normal HD95
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore", UserWarning)
                        hd95_val = compute_hausdorff_distance(
                            pred_i, target_i,
                            percentile=95.0,
                            include_background=False
                        ).item()
                except Exception as exc:
                    # Best effort: keep the epoch running, but make the
                    # failure visible instead of silently recording the
                    # fallback value.
                    warnings.warn(
                        f"HD95 computation failed ({exc!r}); "
                        "using image diagonal as fallback",
                        RuntimeWarning,
                    )
                    hd95_val = self._image_diagonal(pred_i)
                else:
                    if np.isnan(hd95_val) or np.isinf(hd95_val):
                        # Fallback to maximum distance if computation fails
                        hd95_val = self._image_diagonal(pred_i)

                hd95_values.append(hd95_val)
                self.hd95_cases['both_present'] += 1

        return float(np.mean(hd95_values)) if hd95_values else 0.0

    def compute_metrics(self, preds, targets):
        """Score one batch and append the results to ``running_metrics``.

        Both inputs are binarised with a 0.5 threshold, so ``preds`` is
        presumably expected to hold probabilities rather than logits
        (TODO confirm against the caller).

        Returns:
            Tuple ``(iou, dice, accuracy, hd95)`` of batch-level scalars.
        """
        # Ensure binary predictions
        preds = (preds > 0.5).float()
        targets = (targets > 0.5).float()

        # Flatten for pixel-wise computations
        preds_flat = preds.reshape(preds.size(0), -1)
        targets_flat = targets.reshape(targets.size(0), -1)

        pred_sum = preds_flat.sum(dim=1)
        target_sum = targets_flat.sum(dim=1)

        # IoU computation
        intersection = (preds_flat * targets_flat).sum(dim=1)
        union = pred_sum + target_sum - intersection

        # Handle empty masks: if both are empty, IoU = 1 (perfect match)
        iou = torch.where(
            union == 0,
            torch.ones_like(union),
            intersection / (union + self.epsilon)
        )
        iou_scalar = iou.mean().item()

        # Handle empty masks: if both are empty, Dice = 1 (perfect match)
        dice = torch.where(
            (pred_sum + target_sum) == 0,
            torch.ones_like(intersection),
            (2 * intersection + self.epsilon) / (pred_sum + target_sum + self.epsilon)
        )
        dice_scalar = dice.mean().item()

        # Pixel accuracy as an exact count of matching elements, computed on
        # the tensors' own device (no full-tensor host copy needed).
        matches = (preds_flat == targets_flat).sum().item()
        accuracy = matches / preds_flat.numel()

        # HD95 computation with empty mask handling
        hd95_scalar = self._compute_hd95_with_empty_handling(preds, targets)

        # Store all metrics
        self.running_metrics['iou'].append(iou_scalar)
        self.running_metrics['dice'].append(dice_scalar)
        self.running_metrics['accuracy'].append(accuracy)
        self.running_metrics['hd95'].append(hd95_scalar)

        return iou_scalar, dice_scalar, accuracy, hd95_scalar

    def get_epoch_metrics(self):
        """Return the mean of the stored per-batch values (0.0 if none)."""
        return {
            name: (mean(values) if values else 0.0)
            for name, values in self.running_metrics.items()
        }

    def get_hd95_case_summary(self):
        """Return how often each HD95 situation occurred.

        When at least one sample has been counted, each case maps to a dict
        with ``'count'`` and ``'percentage'``. When nothing has been counted
        yet, a copy of the plain zero counters is returned instead.
        """
        total = sum(self.hd95_cases.values())
        if total == 0:
            # Copy so callers cannot mutate the internal counters.
            return dict(self.hd95_cases)

        return {
            case: {'count': count, 'percentage': (count / total) * 100}
            for case, count in self.hd95_cases.items()
        }

    def reset(self):
        """Clear stored metrics and HD95 counters for a new epoch."""
        self.running_metrics = {name: [] for name in self.running_metrics}
        self.hd95_cases = self._new_case_counts()

loss.py
import torch
import torch.nn as nn
from statistics import mean
import torch.nn.functional as F

class BCELoss(nn.Module):
    """Binary cross-entropy over per-sample flattened predictions.

    Wraps ``nn.BCELoss``, so ``pred`` must already contain probabilities
    in [0, 1] (e.g. the output of a sigmoid), not raw logits.
    """

    def __init__(self):
        super().__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, pred, target):
        """Return the mean BCE between ``pred`` and ``target``.

        Both tensors are flattened to ``(batch, -1)`` before the loss is
        applied.
        """
        size = pred.size(0)
        # reshape (not view) so non-contiguous network outputs are accepted.
        pred_ = pred.reshape(size, -1)
        target_ = target.reshape(size, -1)
        return self.bceloss(pred_, target_)

class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the batch.

    Args:
        smooth: Constant added to both numerator and denominator of the
            Dice ratio. With it, a sample whose prediction and target are
            both all-zero gets a Dice score of 1 (loss contribution 0).
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        """Return ``1 - mean(per-sample Dice score)``."""
        size = pred.size(0)
        # reshape (not view) so non-contiguous network outputs are accepted.
        pred_ = pred.reshape(size, -1)
        target_ = target.reshape(size, -1)
        intersection = (pred_ * target_).sum(1)
        dice_score = (2 * intersection + self.smooth) / (
            pred_.sum(1) + target_.sum(1) + self.smooth
        )
        return 1 - dice_score.sum() / size

class BceDiceLoss(nn.Module):
    """Weighted sum of Dice loss and binary cross-entropy.

    Args:
        wb: Weight applied to the BCE term.
        wd: Weight applied to the Dice term.
    """

    def __init__(self, wb=1, wd=1):
        super().__init__()
        self.bce = BCELoss()
        self.dice = DiceLoss()
        self.wb = wb
        self.wd = wd

    def forward(self, pred, target):
        """Return ``wd * dice_loss + wb * bce_loss``."""
        bceloss = self.bce(pred, target)
        diceloss = self.dice(pred, target)
        return self.wd * diceloss + self.wb * bceloss

    def get_individual_losses(self, pred, target):
        """Return the unweighted BCE and Dice values as Python floats.

        The values are only read out via ``.item()``, so they are computed
        without recording an autograd graph.
        """
        with torch.no_grad():
            bceloss = self.bce(pred, target)
            diceloss = self.dice(pred, target)
        return {'bce': bceloss.item(), 'dice': diceloss.item()}

class LossTracker:
    """Compute the training loss and keep per-batch component values.

    Args:
        loss_type: Name of the loss to build. Only ``'bce_dice'`` is
            supported.
        **loss_kwargs: Forwarded to the loss constructor.

    Raises:
        ValueError: If ``loss_type`` is not a supported name.
    """

    def __init__(self, loss_type='bce_dice', **loss_kwargs):
        self.loss_type = loss_type

        if loss_type == 'bce_dice':
            self.loss_fn = BceDiceLoss(**loss_kwargs)
            self.loss_names = ['dice', 'bce']
        else:
            # Fail here with a clear message rather than later with an
            # AttributeError on the missing loss_names attribute.
            raise ValueError(f"Unsupported loss_type: {loss_type!r}")

        self.running_losses = {name: [] for name in self.loss_names}

    def compute_losses(self, preds, targets):
        """Return the combined loss and record its individual components.

        The returned value is what ``loss_fn`` produces for the batch; the
        component values from ``get_individual_losses`` are appended to
        ``running_losses``.
        """
        combined_loss = self.loss_fn(preds, targets)

        if self.loss_type == 'bce_dice':
            individual_losses = self.loss_fn.get_individual_losses(preds, targets)
            self.running_losses['bce'].append(individual_losses['bce'])
            self.running_losses['dice'].append(individual_losses['dice'])

        return combined_loss

    def get_epoch_losses(self):
        """Return the mean of each recorded component (0.0 if none recorded)."""
        return {
            name: (mean(values) if values else 0.0)
            for name, values in self.running_losses.items()
        }

    def reset(self):
        """Clear the recorded component values for a new epoch."""
        self.running_losses = {name: [] for name in self.running_losses}

I am using a combined loss weighted 0.6 Dice and 0.4 BCE.
My task is binary segmentation of medical images, and I feed the images and masks to the model in patches, so many patches are empty (background only).

Hello @Areej_Ahmad1

When your validation loss and metrics outperform your training ones, it can suggest a few things. One common reason is that the validation dataset might be easier for the model to learn from than the training dataset.

I recommend revisiting how your training and validation sets are split. Performing an Exploratory Data Analysis (EDA) on both can help you determine whether the validation set has a similar distribution of image representations as the training set. Ensuring consistency between the two can lead to more reliable model evaluation.


Machine Learning Engineer at RidgeRun.ai
Contact us: support@ridgerun.ai

1 Like

I used various augmentation techniques for the training set but not for the validation set.

Hello @Areej_Ahmad1,

I had a similar problem during my training. I used a cluster-based split for my dataset to create train and validation splits. The issue was that my training data had very similar data appearing in the validation set as well.

What type of splitting are you currently using?

Hello @Areej_Ahmad1

That sounds like a good clue on why your validation set metrics are better.

Generally speaking, you should augment your whole dataset, and then split it into the training, validation and testing sets with equal distributions. If you augment only the training set, this would make the validation set easier to predict for your model, since it was trained with much more complex images.


Machine Learning Engineer at RidgeRun.ai
Contact us: support@ridgerun.ai