metrics.py
from sklearn.metrics import accuracy_score
import torch
from monai.metrics import compute_hausdorff_distance
from statistics import mean
import numpy as np
import warnings
class BinarySegMetrics:
    def __init__(self, epsilon=1e-6, hd95_empty_penalty=None):
        self.epsilon = epsilon
        self.hd95_empty_penalty = hd95_empty_penalty
        self.running_metrics = {
            'iou': [],
            'dice': [],
            'accuracy': [],
            'hd95': []
        }
        self.hd95_cases = {
            'both_empty': 0,      # Both pred and target empty (perfect match)
            'false_positive': 0,  # Target empty, pred not empty
            'false_negative': 0,  # Pred empty, target not empty
            'both_present': 0     # Both have content
        }

    def _get_empty_mask_flags(self, preds, targets):
        """Check which masks are empty for each item in the batch."""
        preds_flat = preds.view(preds.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)
        pred_empty = (preds_flat.sum(dim=1) == 0)
        target_empty = (targets_flat.sum(dim=1) == 0)
        return pred_empty, target_empty
    def _compute_hd95_with_empty_handling(self, preds, targets):
        """Compute HD95 per batch item, with explicit handling of empty masks."""
        pred_empty, target_empty = self._get_empty_mask_flags(preds, targets)
        batch_size = preds.size(0)
        hd95_values = []
        for i in range(batch_size):
            pred_i = preds[i:i + 1]
            target_i = targets[i:i + 1]
            if pred_empty[i] and target_empty[i]:
                # Both empty - perfect segmentation
                hd95_values.append(0.0)
                self.hd95_cases['both_empty'] += 1
            elif target_empty[i] and not pred_empty[i]:
                # False positive - target empty but the model predicted something
                if self.hd95_empty_penalty is None:
                    # Use the image diagonal as the penalty
                    penalty = np.sqrt(pred_i.shape[-1] ** 2 + pred_i.shape[-2] ** 2)
                else:
                    penalty = self.hd95_empty_penalty
                hd95_values.append(penalty)
                self.hd95_cases['false_positive'] += 1
            elif pred_empty[i] and not target_empty[i]:
                # False negative - the model missed the target
                if self.hd95_empty_penalty is None:
                    # Use the image diagonal as the penalty
                    penalty = np.sqrt(pred_i.shape[-1] ** 2 + pred_i.shape[-2] ** 2)
                else:
                    penalty = self.hd95_empty_penalty
                hd95_values.append(penalty)
                self.hd95_cases['false_negative'] += 1
            else:
                # Both have content - compute HD95 normally
                try:
                    with warnings.catch_warnings():
                        warnings.simplefilter("ignore", UserWarning)
                        # MONAI expects batched channel-first tensors [B, C, H, W].
                        # With a single foreground channel the background must be
                        # included, otherwise channel 0 (the only channel) is dropped.
                        hd95_tensor = compute_hausdorff_distance(
                            pred_i, target_i,
                            percentile=95.0,
                            include_background=True
                        )
                    hd95_val = hd95_tensor.item()
                    if np.isnan(hd95_val) or np.isinf(hd95_val):
                        # Fall back to the image diagonal if the computation degenerates
                        hd95_val = np.sqrt(pred_i.shape[-1] ** 2 + pred_i.shape[-2] ** 2)
                    hd95_values.append(hd95_val)
                    self.hd95_cases['both_present'] += 1
                except Exception:
                    # If HD95 computation fails entirely, use the diagonal distance
                    fallback_val = np.sqrt(pred_i.shape[-1] ** 2 + pred_i.shape[-2] ** 2)
                    hd95_values.append(fallback_val)
                    self.hd95_cases['both_present'] += 1
        return np.mean(hd95_values) if hd95_values else 0.0
    def compute_metrics(self, preds, targets):
        """Compute all segmentation metrics for one batch."""
        # Ensure binary masks
        preds = (preds > 0.5).float()
        targets = (targets > 0.5).float()

        # Flatten for pixel-wise computations
        preds_flat = preds.view(preds.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)

        # IoU computation
        intersection = (preds_flat * targets_flat).sum(dim=1)
        union = preds_flat.sum(dim=1) + targets_flat.sum(dim=1) - intersection
        # Handle empty masks: if both are empty, IoU = 1 (perfect match)
        iou = torch.where(
            union == 0,
            torch.ones_like(union),
            intersection / (union + self.epsilon)
        )
        iou_scalar = iou.mean().item()

        # Dice computation
        pred_sum = preds_flat.sum(dim=1)
        target_sum = targets_flat.sum(dim=1)
        # Handle empty masks: if both are empty, Dice = 1 (perfect match)
        dice = torch.where(
            (pred_sum + target_sum) == 0,
            torch.ones_like(intersection),
            (2 * intersection + self.epsilon) / (pred_sum + target_sum + self.epsilon)
        )
        dice_scalar = dice.mean().item()

        # Pixel-wise accuracy (computed on CPU via scikit-learn)
        preds_np = preds.view(-1).cpu().numpy()
        targets_np = targets.view(-1).cpu().numpy()
        accuracy = accuracy_score(targets_np, preds_np)

        # HD95 computation with empty-mask handling
        hd95_scalar = self._compute_hd95_with_empty_handling(preds, targets)

        # Store all metrics
        self.running_metrics['iou'].append(iou_scalar)
        self.running_metrics['dice'].append(dice_scalar)
        self.running_metrics['accuracy'].append(accuracy)
        self.running_metrics['hd95'].append(hd95_scalar)
        return iou_scalar, dice_scalar, accuracy, hd95_scalar

    def get_epoch_metrics(self):
        """Get average metrics for the epoch."""
        epoch_metrics = {}
        for name, values in self.running_metrics.items():
            if values:
                epoch_metrics[name] = mean(values)
            else:
                epoch_metrics[name] = 0.0
        return epoch_metrics

    def get_hd95_case_summary(self):
        """Get a summary of how often each HD95 case occurred."""
        total = sum(self.hd95_cases.values())
        if total == 0:
            return self.hd95_cases
        summary = {}
        for case, count in self.hd95_cases.items():
            summary[case] = {
                'count': count,
                'percentage': (count / total) * 100
            }
        return summary

    def reset(self):
        """Reset all metrics for a new epoch."""
        self.running_metrics = {name: [] for name in self.running_metrics.keys()}
        self.hd95_cases = {
            'both_empty': 0,
            'false_positive': 0,
            'false_negative': 0,
            'both_present': 0
        }
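
For reference, a minimal sketch of how BinarySegMetrics could be driven during validation (model, val_loader and device are placeholder names, not part of the code above; the model is assumed to output sigmoid probabilities):

def validate_one_epoch(model, val_loader, device):
    metrics = BinarySegMetrics(epsilon=1e-6)
    model.eval()
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            probs = model(images)  # assumed shape [B, 1, H, W], values in [0, 1]
            metrics.compute_metrics(probs, masks)
    epoch_metrics = metrics.get_epoch_metrics()    # averaged iou / dice / accuracy / hd95
    hd95_cases = metrics.get_hd95_case_summary()   # how often each empty/non-empty case occurred
    metrics.reset()
    return epoch_metrics, hd95_cases
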
loss.py
import torch
import torch.nn as nn
from statistics import mean
import torch.nn.functional as F
class BCELoss(nn.Module):
    def __init__(self):
        super(BCELoss, self).__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, pred, target):
        # nn.BCELoss expects probabilities, i.e. predictions already passed through a sigmoid
        size = pred.size(0)
        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)
        return self.bceloss(pred_, target_)


class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        size = pred.size(0)
        pred_ = pred.view(size, -1)
        target_ = target.view(size, -1)
        intersection = pred_ * target_
        dice_score = (2 * intersection.sum(1) + self.smooth) / (pred_.sum(1) + target_.sum(1) + self.smooth)
        dice_loss = 1 - dice_score.sum() / size
        return dice_loss


class BceDiceLoss(nn.Module):
    def __init__(self, wb=1, wd=1):
        super(BceDiceLoss, self).__init__()
        self.bce = BCELoss()
        self.dice = DiceLoss()
        self.wb = wb
        self.wd = wd

    def forward(self, pred, target):
        bceloss = self.bce(pred, target)
        diceloss = self.dice(pred, target)
        return self.wd * diceloss + self.wb * bceloss

    def get_individual_losses(self, pred, target):
        """Return the individual (unweighted) loss components as floats."""
        bceloss = self.bce(pred, target)
        diceloss = self.dice(pred, target)
        return {'bce': bceloss.item(), 'dice': diceloss.item()}
class LossTracker:
    def __init__(self, loss_type='bce_dice', **loss_kwargs):
        self.loss_type = loss_type
        if loss_type == 'bce_dice':
            self.loss_fn = BceDiceLoss(**loss_kwargs)
            self.loss_names = ['dice', 'bce']
        else:
            # Only the combined BCE + Dice loss is wired up at the moment
            raise ValueError(f"Unsupported loss_type: {loss_type}")
        self.running_losses = {name: [] for name in self.loss_names}

    def compute_losses(self, preds, targets):
        combined_loss = self.loss_fn(preds, targets)
        if self.loss_type == 'bce_dice':
            individual_losses = self.loss_fn.get_individual_losses(preds, targets)
            self.running_losses['bce'].append(individual_losses['bce'])
            self.running_losses['dice'].append(individual_losses['dice'])
        return combined_loss

    def get_epoch_losses(self):
        epoch_losses = {}
        for name, values in self.running_losses.items():
            if values:
                epoch_losses[name] = mean(values)
            else:
                epoch_losses[name] = 0.0
        return epoch_losses

    def reset(self):
        self.running_losses = {name: [] for name in self.running_losses.keys()}
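
For completeness, a rough sketch of how LossTracker would sit inside a training step (model and optimizer are placeholders, not defined here):

def train_step(model, optimizer, tracker, images, masks):
    # tracker is e.g. LossTracker(loss_type='bce_dice', wb=0.4, wd=0.6)
    optimizer.zero_grad()
    probs = model(images)                        # nn.BCELoss expects sigmoid probabilities
    loss = tracker.compute_losses(probs, masks)
    loss.backward()
    optimizer.step()
    return loss.item()
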
I am weighting the loss as 0.6 * Dice + 0.4 * BCE. Since my task is binary medical image segmentation and I feed the images and masks as patches, many patches are empty (background only).
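
With that weighting the combined loss is instantiated as BceDiceLoss(wb=0.4, wd=0.6). One detail worth illustrating for the empty patches: the smooth term makes the Dice loss exactly 0 when prediction and target are both empty, but any leftover probability mass on an empty target pushes the soft Dice close to 1. A small self-contained check (shapes chosen only for illustration):

import torch

criterion = BceDiceLoss(wb=0.4, wd=0.6)   # total = 0.6 * Dice + 0.4 * BCE

empty_target = torch.zeros(1, 1, 64, 64)

# Prediction exactly zero on an empty patch: Dice loss is 0 thanks to smooth=1.0
print(criterion.dice(torch.zeros(1, 1, 64, 64), empty_target))            # tensor(0.)

# Low but nonzero probabilities everywhere: soft Dice loss is still ~0.98,
# because the small predicted mass has no overlap with the empty target
print(criterion.dice(torch.full((1, 1, 64, 64), 0.01), empty_target))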