# Extending multi-class 2D Dice Loss to 3D

Hi all,

I’m attempting to extend a multi-class 2D dice loss implementation to 3D. My segmentation networks are trained with a combined loss: dice loss + cross-entropy loss. The 2D model learns from this loss, but the 3D model doesn’t seem to learn anything, so I suspect that either the dice loss implementation or the cross-entropy loss calculation is the cause.

Here is the 2D snippet (credit to this discussion: One-hot encoding with autograd (Dice loss)):

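```python
import torch


def _dice_loss_multichannel(output, target, weights=None, ignore_index=None):
    """
    :param output: NxCxHxW tensor of class probabilities (e.g. softmax of the logits)
    :param target: NxHxW LongTensor of class indices
    :param weights: C FloatTensor of per-class weights
    :param ignore_index: int index to ignore from loss
    :return: scalar loss, averaged over the C channels
    """
    eps = 0.0001
    # One-hot target with the same shape as `output` (no gradient needed)
    encoded_target = torch.zeros_like(output)

    if ignore_index is not None:
        # Map ignored pixels to a valid class so scatter_ works, then zero them out
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = output * encoded_target
    # Sum over batch (dim 0) and spatial dims (H, W): one value per channel
    numerator = 2 * intersection.sum(dim=(0, 2, 3))
    denominator = output + encoded_target

    if ignore_index is not None:
        # Ignored pixels must not contribute to the denominator either
        denominator = denominator.masked_fill(mask, 0)
    denominator = denominator.sum(dim=(0, 2, 3)) + eps
    loss_per_channel = weights * (1 - (numerator / denominator))

    return loss_per_channel.sum() / output.size(1)
```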
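Written out, the quantity this snippet computes is the mean over the C channels of a weighted soft dice loss. Here p is `output`, g is the one-hot `encoded_target`, w_c is the class weight (1 when `weights` is None), ε = 0.0001, n runs over the batch and i over the H×W pixels; in 3D, i would simply run over the D×H×W voxels instead:

$$
\mathcal{L} = \frac{1}{C}\sum_{c=1}^{C} w_c\left(1 - \frac{2\sum_{n}\sum_{i} p_{n,c,i}\, g_{n,c,i}}{\sum_{n}\sum_{i}\left(p_{n,c,i} + g_{n,c,i}\right) + \epsilon}\right)
$$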

It seems to me that extending this to 3D should be as easy as adding a dimension and following the same process:

```python
import torch


def _dice_loss_multichannel3D(output, target, weights=None, ignore_index=None):
    """
    Forward pass

    :param output: NxCxDxHxW tensor of class probabilities (e.g. softmax of the logits)
    :param target: NxDxHxW LongTensor of class indices
    :param weights: C FloatTensor of per-class weights
    :param ignore_index: int index to ignore from loss
    :return: scalar loss, averaged over the C channels
    """
    eps = 0.0001
    encoded_target = torch.zeros_like(output)

    if ignore_index is not None:
        mask = target == ignore_index
        target = target.clone()
        target[mask] = 0
        encoded_target.scatter_(1, target.unsqueeze(1), 1)
        mask = mask.unsqueeze(1).expand_as(encoded_target)
        encoded_target[mask] = 0
    else:
        encoded_target.scatter_(1, target.unsqueeze(1), 1)

    if weights is None:
        weights = 1

    intersection = output * encoded_target
    # Reduce over batch (dim 0) and spatial dims (D, H, W), keeping channels.
    # (Chaining torch.squeeze(...).sum(1) would reduce over C instead whenever N > 1.)
    numerator = 2 * intersection.sum(dim=(0, 2, 3, 4))
    denominator = output + encoded_target

    if ignore_index is not None:
        denominator = denominator.masked_fill(mask, 0)
    denominator = denominator.sum(dim=(0, 2, 3, 4)) + eps
    loss_per_channel = weights * (1 - (numerator / denominator))

    return loss_per_channel.sum() / loss_per_channel.size(0)
```
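Since the only real difference between the 2D and 3D versions is which dimensions get summed, one option is a single dimension-agnostic function that reduces over every dimension except the channel dimension. This is a sketch under my own assumptions (the name `dice_loss_nd` is hypothetical, `probs` are already softmax-normalized, and there is no `ignore_index` handling); it should behave the same for batch size 1 and larger batches.

```python
import torch


def dice_loss_nd(probs, target, weights=None, eps=1e-4):
    """Soft multi-class dice loss for N x C x (spatial dims...) inputs.

    probs:   N x C x ... tensor of class probabilities (e.g. softmax over dim 1)
    target:  N x ... LongTensor of class indices in [0, C)
    weights: optional length-C tensor of per-class weights
    """
    # One-hot encode along dim 1; the shape comes from `probs`, so C always matches
    one_hot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1)
    # Reduce over the batch dim and every spatial dim, keeping only the channel dim
    dims = (0,) + tuple(range(2, probs.dim()))
    numerator = 2 * (probs * one_hot).sum(dim=dims)
    denominator = (probs + one_hot).sum(dim=dims) + eps
    loss_per_channel = 1 - numerator / denominator
    if weights is not None:
        loss_per_channel = weights * loss_per_channel
    # Mean over channels == sum / C, as in the snippets above
    return loss_per_channel.mean()


torch.manual_seed(0)
# Hypothetical 3D batch: N=2, C=4, D=8, H=W=16
logits = torch.randn(2, 4, 8, 16, 16, requires_grad=True)
target = torch.randint(0, 4, (2, 8, 16, 16))
loss = dice_loss_nd(torch.softmax(logits, dim=1), target)
loss.backward()
```

The same function accepts NxCxHxW inputs unchanged, which makes it easy to compare the 2D and 3D pipelines with identical loss code.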

If anyone could comment on the validity of this approach, I would appreciate it!

Best,

Chris

This should work (and if you put triple backticks “```” at the top and bottom of your code, it will be better formatted).
The general concern when going from 2D to 3D is whether it becomes prohibitively expensive in memory (or compute), which can happen when the total number of voxels is very large.
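To put rough numbers on the memory point, here is a back-of-the-envelope helper (hypothetical, not part of any library) for the size of one dense float32 tensor. The shapes are made up for illustration. Note that the dice functions in this thread allocate several tensors of the full output size (`encoded_target`, `intersection`, `denominator`) on top of the network's own activations.

```python
def tensor_mib(shape, bytes_per_element=4):
    """Size in MiB of a dense tensor with this shape (float32 by default)."""
    numel = 1
    for size in shape:
        numel *= size
    return numel * bytes_per_element / 2**20


# Hypothetical shapes: N=2, C=4, 256x256 slices, and the same slices stacked 128 deep
print(tensor_mib((2, 4, 256, 256)))       # 2.0
print(tensor_mib((2, 4, 128, 256, 256)))  # 256.0
```

Every full-size intermediate grows by the depth factor (128× in this example), so the loss computation alone can take a noticeable share of GPU memory.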

Best regards

Thomas

Hi Tom,

Thanks for your reply. In my implementation I use a combined loss, dice + cross entropy. If the dice loss isn’t causing the problem, I wonder whether the CE loss is. According to the documentation, I can call the CE loss as below on volumetric outputs and ground truths. Is this call to cross entropy also okay?

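```python
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


class CombinedLoss3D(_Loss):
    """A combination of dice and cross entropy loss for volumetric outputs and ground truths."""

    def __init__(self, weight):
        super().__init__()
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=weight)
        # DiceLoss: dice-loss module not shown in this post
        self.dice_loss = DiceLoss()

    def forward(self, output, ground_truth, weight=None):
        """
        Forward pass

        :param output: torch.tensor (NxCxDxHxW) Network output (logits) not normalized.
        :param ground_truth: torch.tensor (NxDxHxW) LongTensor of class indices
        :param weight: torch.tensor (C) optional per-class weights for the dice term
        :return: scalar
        """
        # The dice term expects probabilities, so normalize the logits first
        # (assumes DiceLoss does not apply softmax internally)
        y_2 = self.dice_loss(F.softmax(output, dim=1), ground_truth, weight)
        # CrossEntropyLoss applies log-softmax itself, so it takes the raw logits
        y_1 = self.cross_entropy_loss(output, ground_truth)

        return y_1 + y_2
```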
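`nn.CrossEntropyLoss` accepts K-dimensional inputs of shape (N, C, d1, ..., dK) with int64 targets of shape (N, d1, ..., dK), so a 5D logits tensor with a 4D target is a supported call. One way to sanity-check this is to compare it against the same loss computed on a flat list of voxels. The sizes and class weights below are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, D, H, W = 2, 4, 8, 16, 16                 # hypothetical sizes
logits = torch.randn(N, C, D, H, W)             # raw, unnormalized network output
target = torch.randint(0, C, (N, D, H, W))      # int64 class index per voxel
class_weight = torch.tensor([1.0, 2.0, 1.0, 0.5])  # hypothetical per-class weights

ce = nn.CrossEntropyLoss(weight=class_weight)
loss_volumetric = ce(logits, target)

# Same loss on (N*D*H*W) x C logits: move channels last, then flatten voxels
flat_logits = logits.permute(0, 2, 3, 4, 1).reshape(-1, C)
flat_target = target.reshape(-1)
loss_flat = ce(flat_logits, flat_target)

print(torch.allclose(loss_volumetric, loss_flat))  # True
```

Both calls average over every voxel (weighted by the target class), so the volumetric call is equivalent to treating each voxel as an independent classification sample.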

Thank you!

Hi @cwat_tum. I am trying to implement a `dice loss` for a 3D segmentation task, and I tried using the loss function defined here.

However, it gives me a size-mismatch error because one of the classes is missing in some of the volumes. I think either the function is only for the binary case, or it requires the `pred` and `target` tensors to have the same shape.

I was wondering if you figured it out for the `multi-class` case and could help me out. Thank you.
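The code that raised the error isn't shown, so this is an assumption: one common way to get exactly this mismatch is building the one-hot target with `torch.nn.functional.one_hot` without `num_classes`, which then infers the class count as one greater than the largest label present. If the highest class is missing from a volume, the one-hot tensor gets fewer channels than the network output. Passing `num_classes` explicitly avoids this; the `scatter_`-based encoding in the snippets earlier in this thread avoids it too, because it takes its shape from `output`. The sizes below are hypothetical.

```python
import torch
import torch.nn.functional as F

num_classes = 4  # hypothetical: the network outputs 4 channels
logits = torch.randn(1, num_classes, 8, 16, 16)

# Hypothetical volume in which class 3 never appears
target = torch.zeros(1, 8, 16, 16, dtype=torch.long)
target[..., 8:] = 1
target[..., 12:] = 2

# Without num_classes, the class axis is inferred as target.max() + 1 == 3
inferred = F.one_hot(target).movedim(-1, 1)
# With num_classes, the one-hot tensor always matches the network output
explicit = F.one_hot(target, num_classes=num_classes).movedim(-1, 1)

print(inferred.shape)                  # torch.Size([1, 3, 8, 16, 16])
print(explicit.shape)                  # torch.Size([1, 4, 8, 16, 16])
print(explicit.shape == logits.shape)  # True
```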