Unet + Cityscapes underfitting (-ish)

Hi!

I’m having a problem with my Unet implementation: the network seems unable to overfit the training data (the CE loss plateaus at ~3.1 and the per-pixel accuracy at 40-50%).
When I view a training image after every epoch, I can see that the model predicts roughly 50% sky and 50% road, which I guess is a local minimum of the cross-entropy loss.
The model reaches 40-50% per-pixel accuracy by the end of epoch 1 and then stays there.
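For scale, assuming plain unweighted `nn.CrossEntropyLoss`: a model that assigns equal probability to all 34 classes at every pixel gets a loss of exactly log(34) ≈ 3.53, whatever the labels are. So a plateau around 3.1 is only somewhat below a uniform guess. A quick check with random labels:

```python
import math

import torch
import torch.nn.functional as F

num_classes = 34

# All-equal logits = uniform probabilities over the classes at every pixel.
logits = torch.zeros(2, num_classes, 4, 4)           # (N, C, H, W)
target = torch.randint(0, num_classes, (2, 4, 4))    # (N, H, W) class indices
uniform_loss = F.cross_entropy(logits, target).item()

print(f"uniform-guess loss: {uniform_loss:.3f}")
print(f"log(num_classes):   {math.log(num_classes):.3f}")
```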

I am using PyTorch's Adam optimizer with its default settings; I also tried SGD with lr=0.01/0.005/0.001, but the same problem persisted. The model predicts 34 classes (Cityscapes, the 5,000 finely annotated images).
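Because the question is whether the network can overfit at all, a common sanity check is to train repeatedly on one fixed batch. A working model/loss pairing should typically push the loss well below its starting value within a few hundred steps. The sketch below is minimal: the small convolutional `model` is a hypothetical stand-in for the Unet, and the random tensors stand in for a real Cityscapes batch. To run the real check, swap in both, along with the actual optimizer settings:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_classes = 34

# Hypothetical stand-in for the Unet: maps (N, 3, H, W) -> (N, C, H, W) raw logits.
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(),
    nn.Conv2d(64, num_classes, 1),  # no softmax: CrossEntropyLoss wants logits
)

# One fixed batch, reused every step; replace with a real image/label batch.
images = torch.randn(2, 3, 16, 16)
labels = torch.randint(0, num_classes, (2, 16, 16))

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

If the real Unet cannot drive the loss down on a single batch, the problem is more likely in the model output or loss wiring than in the optimizer choice.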

Here is my Unet implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(3x3 conv -> BatchNorm -> ReLU) twice; padding=1 keeps H and W unchanged."""

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.down(x)
        return x


class DownScale(nn.Module):
    """2x2 max-pool (halves H and W) followed by a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super(DownScale, self).__init__()
        self.down_block = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        x = self.down_block(x)
        return x


class UpScale(nn.Module):
    """Transposed-conv upsampling, concatenation with the skip tensor, then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super(UpScale, self).__init__()
        # Doubles H and W and reduces channels from in_channels to out_channels.
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2)
        # After concatenating the skip tensor, the channel count is in_channels again.
        self.doubleconv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # output_size makes the upsampled map match the skip tensor x2 exactly,
        # which matters when a spatial size was odd before pooling.
        x1 = self.upsample(x1, output_size=x2.shape)
        assert x1.shape == x2.shape, f"up size:{x1.shape}, previous:{x2.shape}"
        x = torch.cat([x2, x1], dim=1)
        x = self.doubleconv(x)
        return x


class Unet(nn.Module):

    def __init__(self, input_channels, num_classes):
        super(Unet, self).__init__()
        self.inconv = DoubleConv(input_channels, 64)
        self.down_block1 = DownScale(64, 128)
        self.down_block2 = DownScale(128, 256)
        self.down_block3 = DownScale(256, 512)
        self.down_block4 = DownScale(512, 1024)
        self.up_block1 = UpScale(1024, 512)
        self.up_block2 = UpScale(512, 256)
        self.up_block3 = UpScale(256, 128)
        self.up_block4 = UpScale(128, 64)
        # 1x1 conv: one output channel per class.
        self.outputmap = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each stage's output for the skip connections.
        x1 = self.inconv(x)
        x2 = self.down_block1(x1)
        x3 = self.down_block2(x2)
        x4 = self.down_block3(x3)
        x5 = self.down_block4(x4)
        # Decoder: upsample and merge with the matching encoder stage.
        x = self.up_block1(x5, x4)
        x = self.up_block2(x, x3)
        x = self.up_block3(x, x2)
        x = self.up_block4(x, x1)
        x = self.outputmap(x)  # (N, num_classes, H, W) raw logits
        # Converts logits to per-pixel class probabilities. Note that
        # nn.CrossEntropyLoss applies log-softmax itself and expects raw logits.
        y = nn.Softmax(dim=1)(x)
        return y
```
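A relevant detail, in case the training loop passes this model's output straight to `nn.CrossEntropyLoss` (the post mentions CE loss but not which function): `nn.CrossEntropyLoss` applies `log_softmax` to its input internally and expects raw, unnormalized logits. Feeding it softmax probabilities applies softmax twice. That squeezes the inputs into [0, 1] and puts a floor under the loss: with 34 classes, even a perfect prediction scores log(e + 33) - 1 ≈ 2.58. The sketch below compares the two cases on a single confidently predicted pixel:

```python
import math

import torch
import torch.nn as nn

num_classes = 34
criterion = nn.CrossEntropyLoss()  # log_softmax + negative log-likelihood

# One pixel (N=1, H=W=1) whose true class is 0, predicted with very high confidence.
target = torch.zeros(1, 1, 1, dtype=torch.long)     # (N, H, W) class indices
logits = torch.full((1, num_classes, 1, 1), -20.0)  # (N, C, H, W) raw scores
logits[0, 0] = 20.0

loss_from_logits = criterion(logits, target)                       # what CE expects
loss_from_probs = criterion(torch.softmax(logits, dim=1), target)  # softmax applied twice

# Lowest loss reachable when the inputs are probabilities in [0, 1]:
# a one-hot input only gives log_softmax(...)[true] = 1 - log(e + C - 1).
floor = math.log(math.e + num_classes - 1) - 1

print(f"logits: {loss_from_logits.item():.4f}")
print(f"probabilities: {loss_from_probs.item():.4f} (floor {floor:.4f})")
```

If this applies, the corresponding change would be to return `self.outputmap(x)` from `Unet.forward`. You would then apply `torch.softmax`, or just `argmax(dim=1)`, which gives the same labels on logits, only when visualizing or evaluating.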

Here is an image of the model's output after the 2nd epoch (it stays much the same even after 3-4 epochs):

However, on a different semantic segmentation dataset the model did seem to learn:

What I have tried:

  • Cross-entropy with per-class weights (0.1-0.5 for the most frequent classes, sky/road/building, and 1.0 for all others), but that didn't really help.
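For reference, `nn.CrossEntropyLoss` takes per-class weights through its `weight` argument, a 1-D tensor of length C. One common alternative to hand-picked values is inverse pixel-frequency weighting, computed from the label maps. The sketch below makes several assumptions for illustration: the helper name `inverse_frequency_weights`, the normalization to mean 1, and the synthetic labels. It also assumes the masks contain no ignore value such as 255:

```python
import torch
import torch.nn as nn


def inverse_frequency_weights(label_maps, num_classes, eps=1.0):
    """Hypothetical helper: weights proportional to 1 / pixel count, normalized to mean 1.

    label_maps: integer tensor of class indices in [0, num_classes), any shape.
    eps avoids division by zero for classes absent from the sample.
    """
    counts = torch.bincount(label_maps.flatten(), minlength=num_classes).float()
    weights = 1.0 / (counts + eps)
    return weights * num_classes / weights.sum()


num_classes = 34
# Synthetic masks standing in for a sample of training labels.
labels = torch.randint(0, num_classes, (4, 16, 16))
labels[:, :8, :] = 23  # make one class dominate the top half of every mask

weights = inverse_frequency_weights(labels, num_classes)
criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(4, num_classes, 16, 16)  # stand-in raw model output, no softmax
loss = criterion(logits, labels)
print(f"weight of dominant class: {weights[23].item():.4f}, loss: {loss.item():.4f}")
```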

To me this looks like a loss-function problem, with training stuck in a local minimum it can't get out of, but I thought cross-entropy should at least sort of work. I haven't tried Dice loss or any other losses yet; any tips would be wonderful.
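Since Dice loss came up: `torch.nn` has no built-in Dice loss. Below is a minimal sketch of one common multi-class soft Dice formulation on raw logits, often added to cross-entropy rather than used alone. The class name `SoftDiceLoss`, the `smooth` term and the 1:1 weighting with CE are assumptions, not an established recipe for this dataset:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftDiceLoss(nn.Module):
    """Hypothetical sketch: 1 - mean per-class soft Dice, computed from raw logits.

    logits: (N, C, H, W) raw scores; target: (N, H, W) integer class indices.
    """

    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # keeps the ratio defined for classes absent from the batch

    def forward(self, logits, target):
        num_classes = logits.shape[1]
        probs = torch.softmax(logits, dim=1)
        # (N, H, W) -> (N, H, W, C) -> (N, C, H, W) to line up with probs
        one_hot = F.one_hot(target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
        dims = (0, 2, 3)  # sum over batch and pixels, keep one value per class
        intersection = (probs * one_hot).sum(dims)
        cardinality = probs.sum(dims) + one_hot.sum(dims)
        dice = (2 * intersection + self.smooth) / (cardinality + self.smooth)
        return 1 - dice.mean()


num_classes = 34
ce = nn.CrossEntropyLoss()
dice = SoftDiceLoss()

logits = torch.randn(2, num_classes, 8, 8, requires_grad=True)  # stand-in model output
target = torch.randint(0, num_classes, (2, 8, 8))

dice_value = dice(logits, target)
loss = ce(logits, target) + dice_value  # combined objective
loss.backward()
print(f"CE + Dice: {loss.item():.4f}")
```

Because Dice averages over classes rather than pixels, a prediction that only gets sky and road right scores poorly on it even when those classes cover most pixels.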

Thanks!