Using a U-Net, the edges of the image are segmented incorrectly in a strange way

Hi, I have a vision project using a 3D U-Net in which I need to segment (or parcellate) images into 5 different labels, plus 1 background label (label 0).
I trained it with a mean Dice (DSC) loss, and although the segmentation looks great, the edges of the image are always segmented as one of the other labels. In my case it was always label 3, so I tried a weighted DSC loss in which label 3 has a weight of 2.0 and all other labels have a weight of 1.0. The result was the same, except that the edges were assigned label 5 instead of label 3.
I also tried the weighted loss with a weight of 0.5 for the background label and 1.0 for all other labels, and got the same result as with the mean DSC loss.
I don't know how to solve this other than with post-processing. Post-processing would probably work, since in the ground-truth segmentation all non-background labels lie somewhere in the middle of the image and never at the edges. However, I believe this can be fixed by changing the loss so that the model learns it by itself.
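For reference, the post-processing workaround could look roughly like this minimal sketch. It assumes the predictions have already been converted to integer class indices of shape [B, D, H, W] (for example via argmax over dim 1) and that a fixed border slab of thickness `margin` never contains foreground in the ground truth; `suppress_border_labels` and `margin` are hypothetical names, not part of the original code.

```python
import torch


def suppress_border_labels(pred_labels: torch.Tensor, margin: int) -> torch.Tensor:
    """Return a copy of pred_labels ([B, D, H, W] integer class indices) in which
    every voxel within `margin` voxels of the volume boundary is set to
    background (label 0)."""
    out = pred_labels.clone()
    if margin <= 0:
        return out
    # Zero a slab of thickness `margin` at both ends of each spatial axis
    for axis in (1, 2, 3):
        n = out.shape[axis]
        out.narrow(axis, 0, min(margin, n)).zero_()
        out.narrow(axis, max(n - margin, 0), min(margin, n)).zero_()
    return out


# Usage: convert probabilities to labels first, then clean the border
probs = torch.softmax(torch.randn(1, 6, 8, 8, 8), dim=1)
pred_labels = probs.argmax(dim=1)  # [B, D, H, W]
cleaned = suppress_border_labels(pred_labels, margin=2)
```

This only hides the symptom rather than changing what the network learns, so it is only safe if the ground truth really never places foreground labels near the border.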

I have included the code below to make the problem easier to understand; since this is medical data, I unfortunately can't share images of the segmentation results. The predicted probabilities are along dimension 1 (dimension 0 is the batch size). The targets have the same layout, but instead of probabilities they contain 1 in the channel of the true label and 0 everywhere else. I also checked the data about ten times for anything that could cause this and couldn't find a problem.
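A minimal sketch of producing that layout, assuming the network outputs raw logits of shape [B, 6, D, H, W] and the labels are stored as integer class indices of shape [B, D, H, W]. `to_probs_and_onehot` is a hypothetical helper, and the toy shapes stand in for the real 192x256x192 volumes.

```python
import torch
import torch.nn.functional as F


def to_probs_and_onehot(logits: torch.Tensor, labels: torch.Tensor, num_classes: int):
    """Return (probabilities, one-hot targets), both with classes along dim 1."""
    # logits: [B, C, D, H, W] raw scores; softmax over the class dimension
    probs = torch.softmax(logits, dim=1)
    # labels: [B, D, H, W] integer class indices in [0, num_classes)
    onehot = F.one_hot(labels.long(), num_classes)   # -> [B, D, H, W, C]
    onehot = onehot.permute(0, 4, 1, 2, 3).float()   # -> [B, C, D, H, W]
    return probs, onehot


# Toy volume instead of the real data
logits = torch.randn(2, 6, 4, 5, 3)
labels = torch.randint(0, 6, (2, 4, 5, 3))
probs, onehot = to_probs_and_onehot(logits, labels, num_classes=6)
```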

I know this is a very specific and long question, so even if you have no idea how to help but still read this far, I really appreciate your time. Thanks for trying to help.

import torch.nn as nn


class DiceLoss(nn.Module):
    """Mean (unweighted) soft Dice loss over all classes, including background."""

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, predicted, targets, eps=1e-3):
        # predicted: softmax probabilities, shape [B, 6, 192, 256, 192]
        # targets:   one-hot labels,        shape [B, 6, 192, 256, 192]
        dices = []
        for i in range(self.num_classes):
            predicted_this_label = predicted[:, i]
            targets_this_label = targets[:, i]
            # Sums run over the batch and all spatial dimensions
            intersection = (predicted_this_label * targets_this_label).sum()
            dices.append((eps + 2 * intersection) / (eps + predicted_this_label.sum() + targets_this_label.sum()))
        return 1 - sum(dices) / self.num_classes

class WeightedDiceLoss(nn.Module):
    """Soft Dice loss in which each class's Dice score is scaled by a per-class weight."""

    def __init__(self, num_classes, weights=None):
        super().__init__()
        self.num_classes = num_classes
        # Default to equal weights; otherwise store the weights that were passed in
        if weights is None:
            self.weights = [1.0] * num_classes
        else:
            self.weights = list(weights)

    def forward(self, predicted, targets, eps=1e-3):
        # predicted: softmax probabilities, targets: one-hot labels, both [B, C, D, H, W]
        dices = []
        for i in range(self.num_classes):
            predicted_this_label = predicted[:, i]
            targets_this_label = targets[:, i]

            intersection = (predicted_this_label * targets_this_label).sum()

            # Apply class-specific weight
            weight = self.weights[i]
            dices.append(weight * (eps + 2 * intersection) / (eps + predicted_this_label.sum() + targets_this_label.sum()))

        # Weighted mean of the per-class Dice scores, so the loss stays in [0, 1]
        return 1 - sum(dices) / sum(self.weights)

I assume your target is a segmentation map? If so, you could try creating a corresponding weight map that increases the weights at the edges, which adds a larger penalty during training whenever an edge pixel is classified as a wrong class.

Thank you for your response.
I just want to make sure I understand you correctly:
Do you mean that, instead of assigning a weight to each label, I should assign a weight to each pixel based on its distance from the center of the image (like an inverted vignette, with values between 1 and 5, for example), and use it in the intersection part of the DSC loss?
If so, do you have an idea how to keep this loss in the [0, 1] range like the other losses here? It isn't necessary for the loss to be in this range; I'm just curious whether it's possible.
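One way to build the "inverted vignette" weight map described here, as a sketch: the weights grow linearly with Euclidean distance from the volume center, from `w_min` at the center to `w_max` at the corners. The function name `radial_weight_map`, the linear profile and the 1 to 5 range are illustrative assumptions. The resulting [D, H, W] tensor would then multiply the per-voxel terms of the Dice loss.

```python
import torch


def radial_weight_map(shape, w_min=1.0, w_max=5.0):
    """Return a weight map of the given spatial shape (e.g. (D, H, W)) that grows
    linearly from w_min at the center to w_max at the corners."""
    # Normalized coordinates in [-1, 1] along each spatial axis
    axes = [torch.linspace(-1.0, 1.0, n) for n in shape]
    grids = torch.meshgrid(*axes, indexing="ij")
    # Distance from the center, divided by the corner distance so it lies in [0, 1]
    dist = torch.sqrt(sum(g ** 2 for g in grids)) / len(shape) ** 0.5
    return w_min + (w_max - w_min) * dist


# For the volumes in the question this would be radial_weight_map((192, 256, 192));
# a small shape keeps the example cheap.
weights = radial_weight_map((5, 5, 5))
# Add two leading axes so it broadcasts against [B, C, D, H, W] tensors
weights_5d = weights[None, None]
```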

Yes, assuming your target already contains the class index for each pixel. You could use an edge detector to isolate the borders and, based on its output, create a weight tensor that increases the weight of the border pixels. Standard losses are not bounded above by 1, but you could of course normalize the weights by their sum if needed.
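A sketch of this suggestion, assuming one-hot targets of shape [B, C, D, H, W]. In place of a dedicated edge detector it uses a morphological gradient (3x3x3 max-pool dilation minus erosion) to mark voxels lying on a label border; this is a swapped-in technique, and `boundary_weight_map`, `voxel_weighted_dice_loss` and `boundary_weight=5.0` are hypothetical. Because the weight map multiplies the intersection and both denominator terms, each per-class Dice score stays in [0, 1] without any separate normalization. Note that this weights label borders inside the volume; to emphasize the image borders instead, a radial or margin-based map could be passed as `weight_map`.

```python
import torch
import torch.nn.functional as F


def boundary_weight_map(onehot: torch.Tensor, boundary_weight: float = 5.0) -> torch.Tensor:
    """onehot: [B, C, D, H, W]. Returns [B, D, H, W] weights: boundary_weight on
    voxels whose 3x3x3 neighborhood contains more than one label, 1 elsewhere."""
    # max_pool3d pads with -inf, so the image border itself is not marked as an edge
    dilated = F.max_pool3d(onehot, kernel_size=3, stride=1, padding=1)
    eroded = -F.max_pool3d(-onehot, kernel_size=3, stride=1, padding=1)
    # A voxel is on a label border if any class channel changes in its neighborhood
    is_border = (dilated - eroded).sum(dim=1) > 0
    return 1.0 + (boundary_weight - 1.0) * is_border.float()


def voxel_weighted_dice_loss(probs: torch.Tensor, onehot: torch.Tensor,
                             weight_map: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    """1 minus the mean over classes of a soft Dice score where every voxel term
    is scaled by weight_map ([B, D, H, W], non-negative)."""
    w = weight_map.unsqueeze(1)  # [B, 1, D, H, W], broadcasts over classes
    dims = (0, 2, 3, 4)          # sum over batch and space, keep classes
    intersection = (w * probs * onehot).sum(dims)
    denom = (w * probs).sum(dims) + (w * onehot).sum(dims)
    dice = (2 * intersection + eps) / (denom + eps)
    return 1 - dice.mean()


# Usage with network logits of shape [B, C, D, H, W]:
# loss = voxel_weighted_dice_loss(torch.softmax(logits, dim=1), onehot,
#                                 boundary_weight_map(onehot))
```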


Edges tend to be an issue in U-Net image generators as well. It's a byproduct of how convolution kernels behave at the image borders.

You could add extra padding around the entire border of your input image, equal to half of the minimum resolution of the U-Net (i.e. a border of 8, 16 or 32 pixels, depending on your model). Fill it with zeros, then crop the same amount off the output.
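A minimal sketch of the padding trick, assuming a 5D input [B, C, D, H, W] and a model whose output has the same spatial size as its input. `predict_with_border_padding` is a hypothetical name and `pad` is one of the values suggested above. The padded sizes must still be divisible by the network's total downsampling factor; with the 192x256x192 volumes from the question and `pad=16`, the padded sizes 224x288x224 are all divisible by 16. Presumably the same padding and cropping would be applied during training, so the loss is computed on the cropped output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def predict_with_border_padding(model: nn.Module, x: torch.Tensor, pad: int = 16) -> torch.Tensor:
    """Zero-pad the three spatial axes of x ([B, C, D, H, W]) by `pad` voxels on
    each side, run the model, and crop the same margin off the output."""
    if pad <= 0:
        return model(x)
    # F.pad lists (before, after) pairs starting from the LAST dimension: W, H, D
    x_padded = F.pad(x, (pad, pad, pad, pad, pad, pad), mode="constant", value=0.0)
    y = model(x_padded)
    return y[..., pad:-pad, pad:-pad, pad:-pad]


# Stand-in for the real U-Net: any model that preserves spatial size works here
toy_model = nn.Conv3d(in_channels=1, out_channels=6, kernel_size=3, padding=1)
x = torch.randn(1, 1, 16, 16, 16)
out = predict_with_border_padding(toy_model, x, pad=8)
```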