I want to implement a binary boundary loss inspired by https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/boundary_loss.py and https://arxiv.org/abs/1812.07032. I have seen `with torch.no_grad()` used
in the forward method of a loss. I am not sure whether torch.no_grad is recommended in the forward method. Since it is applied only to the labels, which are not yet attached to the computational graph, is it a good idea to use it? Will it influence training?
Is it better to have multiple inputs, like def forward(self, preds, targets, distance_map):
?
"""
Inspired by https://github.com/JunMa11/SegLoss/blob/master/losses_pytorch/boundary_loss.py,
which lists many PyTorch losses for segmentation.
"""
import numpy as np
from scipy.ndimage import distance_transform_edt
import torch
def compute_edts_forhdloss(segmentation):
    """Return, per batch item, the summed foreground/background distance maps.

    For every item ``segmentation[i]`` two Euclidean distance transforms are
    computed -- one on the foreground mask and one on its complement -- and
    added together.

    Args:
        segmentation: Array-like mask indexed along its first axis. Any
            non-zero entry is treated as foreground.

    Returns:
        A float64 ``numpy.ndarray`` with the same shape as ``segmentation``.
    """
    # Coerce to bool so that ``~`` below is a logical negation. On an integer
    # mask ``~`` is a bitwise NOT (``~0 == -1``, ``~1 == -2``), which would
    # leave the "negative" mask non-zero everywhere.
    segmentation = np.asarray(segmentation, dtype=bool)
    res = np.zeros(segmentation.shape)
    for i, posmask in enumerate(segmentation):
        negmask = ~posmask
        res[i] = distance_transform_edt(posmask) + distance_transform_edt(negmask)
    return res
class BinaryBDLoss(torch.nn.Module):
    """Boundary loss: predictions weighted by a distance map built from the targets.

    The distance map is computed inside ``forward`` from the thresholded
    targets via ``compute_edts_forhdloss`` and multiplied element-wise with
    the predictions.

    ref: https://github.com/LIVIAETS/surface-loss/blob/108bd9892adca476e6cdf424124bc6268707498e/losses.py#L74

    NOTE(review): the referenced surface loss appears to use a *signed*
    distance map, whereas ``compute_edts_forhdloss`` yields an unsigned one --
    confirm this is the intended formulation.
    """

    def __init__(self, num_channel: int, ndim=4, per_channel: bool = True, reduction: str = "mean"):
        """Configure the loss.

        Args:
            num_channel: Expected size of dimension 1 of ``preds``.
            ndim: Number of dimensions of ``preds`` (4 or 5).
            per_channel: When true and ``num_channel > 1``, the channel
                dimension is kept by the first averaging step so that one
                value per channel is available to ``reduction``.
            reduction: ``'mean'``, ``'sum'`` or ``'none'``; applied to the
                result of the first averaging step.
        """
        super().__init__()
        assert ndim in [4, 5]
        self.per_channel = per_channel
        self.num_channel = num_channel
        self.ndim = ndim
        self.reduction = reduction
        # Default: average over every dimension (yields a scalar).
        self.axis = list(range(ndim))
        # ``and`` rather than ``&``: ``a > 1 & b`` parses as ``a > (1 & b)``,
        # which ignored ``per_channel=False`` whenever ``num_channel >= 1``.
        if self.num_channel > 1 and self.per_channel:
            # Average over batch and spatial dimensions only, keeping channels.
            self.axis = [0, -3, -2, -1] if self.ndim == 5 else [0, -2, -1]

    def forward(self, preds, targets):
        """Compute the loss.

        Args:
            preds: Predictions; documented by the original author as
                ``(batch_size, class, x, y)`` -- presumably with an extra
                spatial dimension when ``ndim == 5``.
            targets: Ground truth with the same number of elements as
                ``preds``; values above 0.5 are treated as foreground.

        Returns:
            The reduced loss tensor (shape depends on ``per_channel`` and
            ``reduction``).

        Raises:
            ValueError: If ``reduction`` is not ``'mean'``, ``'sum'`` or
                ``'none'``.
        """
        if preds.shape[1] != self.num_channel:
            # Insert a singleton dimension at position 1 (same effect as the
            # previous ``view(batch, -1, *rest)``).
            preds = preds.unsqueeze(1)
        # ``reshape`` also copes with non-contiguous targets.
        targets = targets.reshape(preds.shape)

        # The distance map depends on the labels only, so it is a constant
        # with respect to the model parameters: it needs no gradient, and
        # ``no_grad`` here does not change what is back-propagated to ``preds``.
        with torch.no_grad():
            bound = compute_edts_forhdloss(targets.detach().cpu().numpy() > 0.5)
            # ``Tensor.to`` returns a new tensor; the result must be kept,
            # otherwise ``bound`` stays a CPU float64 tensor.
            bound = torch.from_numpy(bound).to(device=preds.device, dtype=preds.dtype)

        # Element-wise product; both tensors have identical shapes here.
        multipled = preds * bound
        bd_loss = torch.mean(multipled, dim=self.axis)

        if self.reduction == 'mean':
            return bd_loss.mean()
        if self.reduction == 'sum':
            return bd_loss.sum()
        if self.reduction == 'none':
            return bd_loss
        raise ValueError('Unexpected reduction {}'.format(self.reduction))