Hi, I have this custom loss function. I am using torch.sum to count the boundary pixels in a binary matrix gt_b. gt_b is a float tensor whose values are either 1.000 or -0.000. When all of its values are -0.000, torch.sum() gives pretty inaccurate values. I suspect an overflow, but how do I deal with that?
'''
Inputs of this loss function will be:
predicted_mask of shape (B, C, H, W), a float tensor with dtype torch.float32
mask of shape (B, C, H, W), a float tensor with dtype torch.float32
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from l1 import l1_distance_transform
def get_bounday_map(x):
    """Return a binary boundary map of ``x`` from central-difference filters.

    Args:
        x: Floating-point tensor of shape (B, 1, H, W).

    Returns:
        Tensor of shape (B, 1, H, W) on the same device and with the same
        dtype as ``x``. A pixel is 1 where the gradient magnitude exceeds
        0.5; everywhere else it is zero (stored as ``-0.0`` because of the
        negation below, which compares and sums exactly like ``0.0``).
    """
    # Build the kernels on the input's device/dtype so the function works on
    # CPU as well as CUDA instead of hard-coding torch.cuda.FloatTensor.
    kernel_v = torch.tensor(
        [[0, -1, 0],
         [0, 0, 0],
         [0, 1, 0]],
        dtype=x.dtype,
        device=x.device,
    ).view(1, 1, 3, 3)
    kernel_h = torch.tensor(
        [[0, 0, 0],
         [-1, 0, 1],
         [0, 0, 0]],
        dtype=x.dtype,
        device=x.device,
    ).view(1, 1, 3, 3)

    # padding=1 keeps H and W unchanged; the zero padding means foreground
    # pixels touching the image border are also reported as boundary.
    x_v = F.conv2d(x, kernel_v, padding=1)
    x_h = F.conv2d(x, kernel_h, padding=1)

    # Gradient magnitude; the 1e-6 presumably keeps sqrt's derivative finite
    # where both differences are exactly zero.
    y = torch.sqrt(torch.pow(x_v, 2) + torch.pow(x_h, 2) + 1e-6)

    # First threshold: zero out every magnitude that is <= 0.5.
    z = F.threshold(y, 0.5, 0)
    # Second threshold on the negated map: surviving magnitudes (-z <= -0.5)
    # become 1, suppressed ones stay at -0.0.
    return F.threshold(-z, -0.5, 1)
class BoundaryLoss(nn.Module):
    """Linear combination of a boundary loss and BCE-with-logits.

    The returned value is ``alpha * boundary_loss + beta * bce_loss``.
    This is intended for binary tasks: only channel 0 of the inputs is used
    for the boundary term.

    NOTE(review): get_bounday_map binarises with hard thresholds, so the
    boundary term appears to contribute zero gradient with respect to
    ``pred``; confirm that training relies on the BCE term (``beta > 0``).
    """

    def __init__(self, alpha=1, beta=0):
        """Store the weights of the boundary term (alpha) and BCE term (beta)."""
        super().__init__()
        self.alpha = alpha
        self.beta = beta

    def forward(self, pred, gt):
        """Compute the combined loss.

        Args:
            pred: Model output before any sigmoid/softmax, shape (B, C, H, W).
            gt: Ground-truth map, shape (B, C, H, W).

        Returns:
            Scalar tensor: ``alpha * boundary_loss + beta * bce_loss``, where
            the boundary loss is normalised by the number of ground-truth
            boundary pixels in the whole mini-batch.
        """
        # BCE is computed on the full tensors, before the channel slicing.
        bce_loss = nn.BCEWithLogitsLoss()(pred, gt)

        # Binary task: keep a single channel, preserving the channel axis.
        gt = gt[:, 0:1, :, :]
        pred = pred[:, 0:1, :, :]

        # Binary boundary maps (values are 1 or -0.0).
        gt_b = get_bounday_map(gt)
        pred_b = get_bounday_map(pred)

        # Number of ground-truth boundary pixels. When the mask has no
        # boundary at all this sum is 0 and the division below used to yield
        # inf/nan; clamping to 1 leaves every non-empty case unchanged.
        nb_boundary = torch.sum(gt_b).clamp(min=1)

        # Distance transform of the inverted boundary map; presumably the
        # per-pixel L1 distance to the nearest ground-truth boundary pixel --
        # confirm against l1.l1_distance_transform, including what it returns
        # when no boundary pixel exists.
        gt_dt = l1_distance_transform(1 - gt_b)

        # Distance-weighted predicted boundary, averaged over the
        # ground-truth boundary pixels; 1e-7 is the original constant offset.
        hadamard = torch.mul(gt_dt, pred_b)
        boundary_loss = 1e-7 + torch.sum(hadamard) / nb_boundary

        return self.alpha * boundary_loss + self.beta * bce_loss