Hi all,
I’m attempting to extend a multi-class 2D Dice loss implementation to 3D. My segmentation models are trained on a combined loss: Dice loss + cross-entropy loss. My 2D implementation learns from this loss, but my 3D implementation doesn’t seem to learn anything, so I suspect that either the Dice loss implementation or the cross-entropy loss calculation is the cause.
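For reference, here is a minimal sketch of a combined cross-entropy + Dice objective on 3D volumes. It assumes the Dice term should receive softmax probabilities while `nn.CrossEntropyLoss` receives raw logits. The post doesn't show how `output` is produced, so treat that as an assumption worth checking. `soft_dice_loss`, `combined_loss` and `dice_weight` are hypothetical names, not from the original code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def soft_dice_loss(probs, target, eps=1e-4):
    """Soft Dice loss averaged over classes (hypothetical helper).

    probs:  N x C x (spatial...) probabilities, e.g. softmax over dim 1
    target: N x (spatial...) LongTensor of class indices
    Works for 2D (N x C x H x W) and 3D (N x C x D x H x W) inputs.
    """
    num_classes = probs.size(1)
    # F.one_hot appends the class dim last; move it to position 1
    one_hot = F.one_hot(target, num_classes).movedim(-1, 1).to(probs.dtype)
    # Reduce over batch and all spatial dims, keeping one value per class
    dims = (0,) + tuple(range(2, probs.dim()))
    intersection = (probs * one_hot).sum(dim=dims)
    denominator = (probs + one_hot).sum(dim=dims) + eps
    return (1 - 2 * intersection / denominator).mean()


def combined_loss(logits, target, dice_weight=1.0):
    # CrossEntropyLoss expects raw logits and applies log-softmax itself;
    # it accepts N x C x D x H x W logits with N x D x H x W targets
    ce = nn.CrossEntropyLoss()(logits, target)
    # The Dice term is computed on probabilities, not on raw logits
    dice = soft_dice_loss(torch.softmax(logits, dim=1), target)
    return ce + dice_weight * dice


logits = torch.randn(2, 3, 4, 8, 8, requires_grad=True)  # N x C x D x H x W
target = torch.randint(0, 3, (2, 4, 8, 8))                # N x D x H x W
loss = combined_loss(logits, target)
loss.backward()
print(loss.item(), logits.grad.shape)
```

Passing raw logits into the Dice term, or applying softmax before `CrossEntropyLoss`, are both plausible reasons a combined loss might stop training well. That is one more thing to rule out alongside the reduction itself.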
Here is the 2D snippet (credit to this discussion: One-hot encoding with autograd (Dice loss)):
```python
def _dice_loss_multichannel(output, target, weights=None, ignore_index=None):
    """
    :param output: NxCxHxW tensor of per-class probabilities (e.g. softmax output)
    :param target: NxHxW LongTensor of class indices
    :param weights: C FloatTensor of per-class weights
    :param ignore_index: int index to ignore from loss
    :return: scalar loss averaged over channels
    """
    eps = 0.0001
    # One-hot encode the target into a zero tensor shaped like output (NxCxHxW)
    encoded_target = output.detach() * 0
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        # Zero the one-hot entries at ignored pixels across all channels
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = output * encoded_target
    # Sum over batch (dim 0), then H, then W: NxCxHxW -> CxHxW -> CxW -> C
    numerator = 2 * intersection.sum(0).sum(1).sum(1)
    denominator = output + encoded_target
    if ignore_index is not None:
        denominator[mask] = 0
    denominator = denominator.sum(0).sum(1).sum(1) + eps
    loss_per_channel = weights * (1 - (numerator / denominator))
    return loss_per_channel.sum() / output.size(1)
```
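To make the two key steps of the 2D snippet concrete, here is a small check on random tensors (the sizes are arbitrary stand-ins). It compares the `scatter_`-based one-hot encoding with `torch.nn.functional.one_hot`, and the chained `.sum(0).sum(1).sum(1)` with an explicit reduction over dims `(0, 2, 3)`:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 4, 5, 6
output = torch.softmax(torch.randn(N, C, H, W), dim=1)
target = torch.randint(0, C, (N, H, W))

# One-hot encoding as in the snippet: zeros shaped like output, then scatter 1s
encoded_target = output.detach() * 0
encoded_target.scatter_(1, target.unsqueeze(1), 1)

# Same encoding with F.one_hot (class dim appended last, so move it to dim 1)
one_hot = F.one_hot(target, C).permute(0, 3, 1, 2).to(output.dtype)
print(torch.equal(encoded_target, one_hot))

# .sum(0).sum(1).sum(1): N x C x H x W -> C x H x W -> C x W -> C
chained = (output * encoded_target).sum(0).sum(1).sum(1)
explicit = (output * encoded_target).sum(dim=(0, 2, 3))
print(chained.shape, torch.allclose(chained, explicit))
```

Because each chained `.sum(1)` shifts the remaining dims left, the 2D chain happens to land on one value per channel. Passing an explicit tuple of dims makes that intent visible and easier to port to other ranks.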
It seems to me that extending this to 3D should be as easy as adding a dimension and following the same process:
```python
def _dice_loss_multichannel3D(output, target, weights=None, ignore_index=None):
    """
    Forward pass
    :param output: NxCxDxHxW tensor of per-class probabilities (e.g. softmax output)
    :param target: NxDxHxW LongTensor of class indices
    :param weights: C FloatTensor of per-class weights
    :param ignore_index: int index to ignore from loss
    :return: scalar loss averaged over channels
    """
    eps = 0.0001
    # One-hot encode the target into a zero tensor shaped like output (NxCxDxHxW)
    encoded_target = output.detach() * 0
    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        # Zero the one-hot entries at ignored voxels across all channels
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    # Reduce over batch (0) and the spatial dims D, H, W (2, 3, 4), keeping C.
    # torch.squeeze(x).sum(1).sum(1).sum(1) only gives per-channel sums when
    # N == 1 (and no other dim is 1); with N > 1 the first .sum(1) sums over
    # the channel dim instead.
    dims = (0, 2, 3, 4)
    intersection = output * encoded_target
    numerator = 2 * intersection.sum(dim=dims)
    denominator = output + encoded_target
    if ignore_index is not None:
        denominator[mask] = 0
    denominator = denominator.sum(dim=dims) + eps
    loss_per_channel = weights * (1 - (numerator / denominator))
    return loss_per_channel.sum() / loss_per_channel.size(0)
```
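In the 3D version, reducing with `torch.squeeze(x).sum(1).sum(1).sum(1)` instead of over explicit dims is fragile. The quick shape check below uses random tensors with arbitrary sizes, and `squeeze_chain` and `explicit` are hypothetical helper names. It suggests that the squeeze chain only yields one value per channel when the batch size is 1:

```python
import torch

C, D, H, W = 3, 4, 5, 6


def squeeze_chain(x):
    # Squeeze-based reduction: drops every size-1 dim, then sums dim 1 three times
    return torch.squeeze(x).sum(1).sum(1).sum(1)


def explicit(x):
    # Reduce over batch (0) and spatial dims D, H, W (2, 3, 4); keep channels
    return x.sum(dim=(0, 2, 3, 4))


x1 = torch.rand(1, C, D, H, W)  # batch size 1
x2 = torch.rand(2, C, D, H, W)  # batch size 2

print(squeeze_chain(x1).shape, explicit(x1).shape)  # torch.Size([3]) torch.Size([3])
print(torch.allclose(squeeze_chain(x1), explicit(x1)))
print(squeeze_chain(x2).shape, explicit(x2).shape)  # torch.Size([2, 6]) torch.Size([3])
```

With N = 2, `squeeze` removes nothing. The chain then sums over C, D and H, leaving an N x W tensor instead of a per-channel vector, and `loss_per_channel.size(0)` becomes N rather than C. If C = 1, `squeeze` would also drop the channel dim.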
If anyone could comment on the validity of this approach, I would appreciate it!
Best,
Chris