Hi, I have a vision project with a 3D U-Net in which I need to segment (or parcellate) images into 5 different labels plus 1 background label (label 0).
I trained it with a mean Dice (DSC) loss. Although the segmentation looks great, the edges of the image are always segmented as one of the foreground labels. In my case it was always label 3, so I tried a weighted DSC loss in which label 3 has a weight of 2.0 and all other labels have a weight of 1.0. The result was the same, except that this time the edges were labeled 5 instead of 3.
I also tried a different weighting: the background label set to 0.5 and all other labels set to 1. That gave the same result as the plain mean DSC loss.
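For reference, the loss used in both experiments (it matches `WeightedDiceLoss` in the code below) can be written as follows, where $v$ runs over every voxel of every sample in the batch, $p_{c,v}$ is the predicted probability and $t_{c,v}$ the one-hot target for class $c$, and $w_c$ is the class weight. With all $w_c = 1$ it reduces to the mean DSC loss.

$$
D_c = \frac{\varepsilon + 2\sum_{v} p_{c,v}\, t_{c,v}}{\varepsilon + \sum_{v} p_{c,v} + \sum_{v} t_{c,v}},
\qquad
\mathcal{L} = 1 - \frac{\sum_{c=0}^{5} w_c\, D_c}{\sum_{c=0}^{5} w_c}
$$

Note that $w_c$ only rescales each class's whole Dice term; it does not single out the misclassified voxels at the image border.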
I don't really know how to solve this other than with post-processing. Post-processing would probably work, since in the ground-truth segmentation all the foreground labels lie somewhere in the middle of the image and never at the edges. Still, I believe this can be fixed by changing the loss so that the model learns it on its own.
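A minimal post-processing sketch, assuming each foreground structure should form a single connected blob (an assumption, not something established in this post): keep only the largest connected component of each label and reset every other voxel of that label to background. It uses `scipy.ndimage.label` (face connectivity by default); `keep_largest_component` is a hypothetical helper name.

```python
import numpy as np
from scipy import ndimage


def keep_largest_component(label_volume, num_classes=6):
    """Keep only the largest connected component of each foreground label.

    label_volume: integer array (e.g. argmax over the class channel), 0 = background.
    Voxels belonging to smaller components are reassigned to background (0).
    """
    cleaned = np.zeros_like(label_volume)
    for c in range(1, num_classes):
        mask = label_volume == c
        components, n = ndimage.label(mask)
        if n == 0:
            continue  # label c was not predicted at all
        # Voxel count of each component (component ids start at 1)
        sizes = np.bincount(components.ravel())[1:]
        largest_id = int(np.argmax(sizes)) + 1
        cleaned[components == largest_id] = c
    return cleaned
```

If a spurious border region is ever larger than the real structure, this heuristic would keep the wrong one, so it is a safety net rather than a fix for the loss itself.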
I'm including the code below to make this easier to understand. Since it is medical data I can't share images of the segmentation results, sorry about that. The predicted probabilities are along dim 1 (dim 0 is the batch dimension). The targets have the same layout, except that instead of probabilities they hold 1 on the true-label channel and 0 everywhere else (one-hot). I've also checked the data about 10 times looking for something that could cause this and couldn't find a problem.
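A small sketch of that layout, assuming the network outputs raw logits and the labels are stored as an integer map (both assumptions; if the network already ends in a softmax, don't apply it a second time). Tiny shapes replace `[B, 6, 192, 256, 192]` so it runs quickly.

```python
import torch
import torch.nn.functional as F

num_classes = 6
logits = torch.randn(2, num_classes, 8, 8, 8)         # [B, C, D, H, W] network output
labels = torch.randint(0, num_classes, (2, 8, 8, 8))  # [B, D, H, W] integer label map

# Probabilities along dim 1 (the class channel)
predicted = torch.softmax(logits, dim=1)

# One-hot targets: F.one_hot appends the class axis last, so move it to dim 1
targets = F.one_hot(labels, num_classes).permute(0, 4, 1, 2, 3).float()
```

A quick sanity check on real batches is that `predicted.sum(dim=1)` is all ones and `targets.argmax(dim=1)` reproduces the label map.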
I know this is a very specific (and long) problem, so even if you don't have an idea of how to help but still read this far, I really appreciate your time. Thanks for trying to help.
```python
import torch.nn as nn


class DiceLoss(nn.Module):
    """Mean soft Dice loss over all classes (background included)."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, predicted, targets, eps=1e-3):
        # predicted: class probabilities, shape [-1, 6, 192, 256, 192]
        # targets:   one-hot labels,      shape [-1, 6, 192, 256, 192]
        dices = []
        for i in range(self.num_classes):
            predicted_this_label = predicted[:, i]
            targets_this_label = targets[:, i]
            # Sums run over the whole batch and all voxels at once
            intersection = (predicted_this_label * targets_this_label).sum()
            dices.append((eps + 2 * intersection)
                         / (eps + predicted_this_label.sum() + targets_this_label.sum()))
        return 1 - sum(dices) / self.num_classes


class WeightedDiceLoss(nn.Module):
    """Soft Dice loss with per-class weights; equals DiceLoss when all weights are 1."""

    def __init__(self, num_classes, weights=None):
        super().__init__()
        self.num_classes = num_classes
        self.weights = [1.0] * num_classes if weights is None else list(weights)

    def forward(self, predicted, targets, eps=1e-3):
        dices = []
        for i in range(self.num_classes):
            predicted_this_label = predicted[:, i]
            targets_this_label = targets[:, i]
            intersection = (predicted_this_label * targets_this_label).sum()
            # Apply class-specific weight
            weight = self.weights[i]
            dices.append(weight * (eps + 2 * intersection)
                         / (eps + predicted_this_label.sum() + targets_this_label.sum()))
        # Normalize by the total weight so a perfect prediction still gives loss 0
        return 1 - sum(dices) / sum(self.weights)
```
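For experimenting with different weightings, a vectorized sketch of the same math as `WeightedDiceLoss` (no Python loop over classes) may be handy. `weighted_dice_loss` is a hypothetical name, and the weight lists below are just the two configurations described above.

```python
import torch
import torch.nn.functional as F


def weighted_dice_loss(predicted, targets, weights=None, eps=1e-3):
    """Vectorized soft Dice loss with optional per-class weights.

    predicted, targets: [B, C, ...] tensors (probabilities and one-hot labels).
    weights: optional sequence of length C (defaults to all ones).
    """
    num_classes = predicted.shape[1]
    if weights is None:
        weights = torch.ones(num_classes, dtype=predicted.dtype, device=predicted.device)
    else:
        weights = torch.as_tensor(weights, dtype=predicted.dtype, device=predicted.device)
    # Sum over the batch (dim 0) and every spatial dim, keeping the class dim
    reduce_dims = [0] + list(range(2, predicted.dim()))
    intersection = (predicted * targets).sum(dim=reduce_dims)
    denominator = predicted.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims)
    dices = (eps + 2 * intersection) / (eps + denominator)  # shape [C]
    return 1 - (weights * dices).sum() / weights.sum()


# Usage with random one-hot targets
labels = torch.randint(0, 6, (1, 4, 4, 4))
targets = F.one_hot(labels, 6).permute(0, 4, 1, 2, 3).float()
uniform = torch.full_like(targets, 1 / 6)  # an uninformative prediction

for w in (None, [1, 1, 1, 2, 1, 1], [0.5, 1, 1, 1, 1, 1]):
    perfect = weighted_dice_loss(targets, targets, weights=w)  # 0 for a perfect prediction
    poor = weighted_dice_loss(uniform, targets, weights=w)
    print(w, perfect.item(), poor.item())
```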