Model not working for Multiclass Segmentation

I am training a model for a multiclass segmentation problem. I have 3 classes, and the images are 512x512 with 1 channel. The classes in my dataset are imbalanced. The issue is that the model does not perform well on the multiclass segmentation task.

I have tried Cross Entropy Loss, Dice Loss, Jaccard Loss, and a combination of losses (Jaccard + Focal).
Cross Entropy is working fine, but the results are not satisfactory.

What changes should I make to it?

Below is the code of the model

class AxialDW(nn.Module):
    """Axial depthwise convolution with a residual connection.

    Runs one depthwise convolution along the height axis (kernel ``(h, 1)``)
    and one along the width axis (kernel ``(1, w)``) on the same input and
    adds both results to the input, so the output has the input's shape.

    Args:
        dim: Number of input (and output) channels; also the number of
            groups, which makes both convolutions depthwise.
        mixer_kernel: ``(h, w)`` kernel extents for the height and width
            convolutions. Odd values are expected so that symmetric padding
            can preserve the spatial size.
        dilation: Dilation applied to both convolutions.
    """

    def __init__(self, dim, mixer_kernel, dilation=1):
        super().__init__()
        h, w = mixer_kernel
        # A stride-1 dilated convolution spans dilation * (k - 1) + 1 pixels,
        # so padding half of dilation * (k - 1) on each side keeps the size.
        # (The previous max(k // 2, dilation) only did so for some k/dilation
        # pairs and broke the residual add for the others.)
        pad_h = dilation * (h - 1) // 2
        pad_w = dilation * (w - 1) // 2
        self.dw_h = nn.Conv2d(dim, dim, kernel_size=(h, 1), padding=(pad_h, 0), groups=dim, dilation=dilation)
        self.dw_w = nn.Conv2d(dim, dim, kernel_size=(1, w), padding=(0, pad_w), groups=dim, dilation=dilation)

    def forward(self, x):
        """Return ``x`` plus its height-wise and width-wise depthwise convolutions."""
        x = x + self.dw_h(x) + self.dw_w(x)
        return x


class EncoderBlock(nn.Module):
    """Encoding then downsampling.

    Applies an axial depthwise convolution and batch normalisation to produce
    the skip tensor, then a 1x1 convolution, 2x2 max pooling and GELU to
    produce the downsampled output.

    Args:
        in_c: Number of input channels (also the channel count of the skip).
        out_c: Number of output channels after the 1x1 convolution.
        mixer_kernel: ``(h, w)`` kernel extents passed to ``AxialDW``.
    """

    def __init__(self, in_c, out_c, mixer_kernel=(7, 7)):
        super().__init__()
        # Forward the caller's kernel size; it used to be hard-coded to
        # (7, 7), which silently ignored the argument.
        self.dw = AxialDW(in_c, mixer_kernel=mixer_kernel)
        self.bn = nn.BatchNorm2d(in_c)
        self.pw = nn.Conv2d(in_c, out_c, kernel_size=1)
        self.down = nn.MaxPool2d((2, 2))
        self.act = nn.GELU()

    def forward(self, x):
        """Return ``(downsampled, skip)``.

        ``skip`` keeps the input resolution and ``in_c`` channels;
        ``downsampled`` has ``out_c`` channels and is pooled by 2x2.
        """
        skip = self.bn(self.dw(x))
        x = self.act(self.down(self.pw(skip)))
        return x, skip


class DecoderBlock(nn.Module):
    """Upsampling then decoding.

    Upsamples the input by a factor of 2, concatenates it with the skip
    tensor along the channel axis, and fuses the result with a 1x1
    convolution, batch normalisation, an axial depthwise convolution, a
    second 1x1 convolution and GELU.

    Args:
        in_c: Number of channels of the tensor being upsampled.
        out_c: Number of channels of the skip tensor and of the output.
        mixer_kernel: ``(h, w)`` kernel extents passed to ``AxialDW``.
    """

    def __init__(self, in_c, out_c, mixer_kernel=(7, 7)):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2)
        # The concatenated tensor carries in_c (upsampled) + out_c (skip) channels.
        self.pw = nn.Conv2d(in_c + out_c, out_c, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_c)
        # Forward the caller's kernel size; it used to be hard-coded to
        # (7, 7), which silently ignored the argument.
        self.dw = AxialDW(out_c, mixer_kernel=mixer_kernel)
        self.act = nn.GELU()
        self.pw2 = nn.Conv2d(out_c, out_c, kernel_size=1)

    def forward(self, x, skip):
        """Upsample ``x``, merge it with ``skip`` and return the decoded tensor.

        ``skip`` must have the same spatial size as the upsampled ``x``,
        otherwise the channel-wise concatenation fails.
        """
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.act(self.pw2(self.dw(self.bn(self.pw(x)))))
        return x


class BottleNeckBlock(nn.Module):
    """Bottleneck built from axial depthwise convolutions at several dilations.

    The input is reduced to ``dim // 4`` channels by a 1x1 convolution. The
    reduced tensor and three axial depthwise branches computed from it
    (dilations 1, 2 and 3) are concatenated, batch-normalised, projected back
    to ``dim`` channels by a second 1x1 convolution and passed through GELU.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()

        branch_dim = dim // 4
        fused_dim = 4 * branch_dim  # reduced tensor + three dilated branches

        self.pw1 = nn.Conv2d(dim, branch_dim, kernel_size=1)
        self.dw1 = AxialDW(branch_dim, mixer_kernel=(3, 3), dilation=1)
        self.dw2 = AxialDW(branch_dim, mixer_kernel=(3, 3), dilation=2)
        self.dw3 = AxialDW(branch_dim, mixer_kernel=(3, 3), dilation=3)

        self.bn = nn.BatchNorm2d(fused_dim)
        self.pw2 = nn.Conv2d(fused_dim, dim, kernel_size=1)
        self.act = nn.GELU()

    def forward(self, x):
        """Return the multi-dilation bottleneck features for ``x``."""
        reduced = self.pw1(x)
        # Every branch reads the same reduced tensor; the reduced tensor
        # itself is kept as the first group of channels.
        branches = [reduced]
        for branch in (self.dw1, self.dw2, self.dw3):
            branches.append(branch(reduced))
        fused = torch.cat(branches, 1)
        return self.act(self.pw2(self.bn(fused)))


class ULite(nn.Module):
    """U-shaped segmentation network built from axial depthwise blocks.

    A 7x7 stem convolution is followed by five encoder stages (each halving
    the spatial size), a dilated bottleneck, five decoder stages that consume
    the encoder skips in reverse order, and a 1x1 convolution that maps to
    ``num_classes`` channels.

    Because of the five 2x poolings, the input height and width should be
    divisible by 32; otherwise the upsampled maps will not match the skip
    tensors in the decoder.

    Args:
        freeze_model: If true, ``requires_grad`` is disabled on every
            parameter right after construction.
        num_classes: Number of output channels (one per class).
        in_channels: Number of channels of the input image.
        apply_softmax: If true (the default, matching the previous
            behaviour), ``forward`` returns per-pixel class probabilities.
            Set this to ``False`` when training with ``nn.CrossEntropyLoss``
            or any other criterion that expects raw logits, since such
            criteria apply (log-)softmax internally.
    """

    def __init__(self, freeze_model, num_classes=3, in_channels=1, apply_softmax=True):
        super().__init__()
        self.apply_softmax = apply_softmax

        # Encoder
        self.conv_in = nn.Conv2d(in_channels, 16, kernel_size=7, padding=3)
        self.e1 = EncoderBlock(16, 32)
        self.e2 = EncoderBlock(32, 64)
        self.e3 = EncoderBlock(64, 128)
        self.e4 = EncoderBlock(128, 256)
        self.e5 = EncoderBlock(256, 512)

        # Bottleneck
        self.b5 = BottleNeckBlock(512)

        # Decoder
        self.d5 = DecoderBlock(512, 256)
        self.d4 = DecoderBlock(256, 128)
        self.d3 = DecoderBlock(128, 64)
        self.d2 = DecoderBlock(64, 32)
        self.d1 = DecoderBlock(32, 16)
        self.conv_out = nn.Conv2d(16, num_classes, kernel_size=1)

        if freeze_model:
            self.freeze_model()

    def forward(self, x):
        """Return per-pixel class scores with ``num_classes`` channels.

        The scores are probabilities when ``apply_softmax`` is true and raw
        logits otherwise.
        """
        # Encoder
        x = self.conv_in(x)
        x, skip1 = self.e1(x)
        x, skip2 = self.e2(x)
        x, skip3 = self.e3(x)
        x, skip4 = self.e4(x)
        x, skip5 = self.e5(x)

        # Bottleneck: 512 channels at 1/32 of the input resolution
        x = self.b5(x)

        # Decoder
        x = self.d5(x, skip5)
        x = self.d4(x, skip4)
        x = self.d3(x, skip3)
        x = self.d2(x, skip2)
        x = self.d1(x, skip1)
        x = self.conv_out(x)

        # Optional softmax over the class axis. Disable it (apply_softmax=False)
        # for criteria that expect raw logits, such as nn.CrossEntropyLoss.
        if self.apply_softmax:
            x = F.softmax(x, dim=1)
        return x

    def freeze_model(self):
        """Disable gradient computation for every parameter of the model."""
        for param in self.parameters():
            param.requires_grad = False

Jaccard + Focal Loss
jacc+focalloss

Jaccard Loss
jaccardloss

Remove the F.softmax call if you are using nn.CrossEntropyLoss as this criterion expects raw logits.

@ptrblck As I said earlier, Cross Entropy is working fine, but the results are not satisfactory. The IoU scores for two of the classes are very low. My images are very complex. What should I do now?

crossentropy

Your code contains an error and you are passing probabilities into a criterion which expects raw logits. I’m unsure how

is relevant. Did you fix your code based on my previous post? If so, did you see any improvement? If not, why not?

@ptrblck Yes, I removed x = F.softmax(x, dim=1) from my model and then retrained it. The training and validation curves I shared in my last post are from that run.
The scores are almost the same as they were before I removed that line.

Could you describe what kind of issues you are seeing? Is the model overfitting to the majority class? If so, could you describe the class imbalance, and did you try adding class weights to the criterion?