Semantic Segmentation: U-Net overfits on Pascal VOC 2012

Hello there! I am doing semantic segmentation on PASCAL VOC 2012. I will show you the relevant fragments of my code:

First of all, these are my VOC colormap and class names:

VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']
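
The two lists are index-aligned: index 1 is 'aeroplane' with colour [128, 0, 0], and channel i of the one-hot mask below corresponds to VOC_CLASSES[i]. A quick way to eyeball the mapping:

for idx, (name, colour) in enumerate(zip(VOC_CLASSES, VOC_COLORMAP)):
    print(idx, name, colour)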

This is my Dataset:

class VOCDataset(torch.utils.data.Dataset):

    """ Pascal VOC2012 Dataset. Read images, apply augmentation and preprocessing transformations.

    Args:
        mode (str): whether this is the train or validation split
        class_rgb_values (list): RGB values of the classes to extract from the segmentation mask
        augmentation (albumentations.Compose): data transformation pipeline
            (e.g. crop, resize, etc.)
        preprocessing (albumentations.Compose): data preprocessing
            (e.g. normalization, shape manipulation, etc.)

    """

    def __init__(
            self,
            mode='train',
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing
        self.mode = mode

        # define directory where train and val pictures names are stored
        self.root_path = 'drive/MyDrive/AI/datasets/VOC/VOCdevkit/VOC2012/'
        self.names_path = self.root_path + 'ImageSets/Segmentation/'+ self.mode +'.txt'

        # define image and labels path
        self.image_path = self.root_path + 'JPEGImages/'
        self.label_path = self.root_path + 'SegmentationClass/'

        self.names = []
        self.read_names()

    def __getitem__(self, i):

        # read images and masks
        image = cv2.cvtColor(cv2.imread(self.image_path + self.names[i] + '.jpg'), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.label_path + self.names[i] + '.png'), cv2.COLOR_BGR2RGB)

        # one-hot-encode the mask
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        image = torch.from_numpy(image).to(DEVICE)
        mask = torch.from_numpy(mask).to(DEVICE)
        return image, mask

    def __len__(self):
        # return the number of samples
        return len(self.names)
        
    def read_names(self):
        """
        Read the image/label filenames for this split into self.names
        """
        with open(self.names_path, 'r') as f:
            self.names = [line.strip() for line in f if line.strip()]
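
The one_hot_encode helper called in __getitem__ is not shown above; a minimal sketch consistent with that call site (RGB mask in, one binary channel per class out) would be:

import numpy as np

def one_hot_encode(label, label_values):
    """Turn an RGB mask (H, W, 3) into a one-hot map (H, W, num_classes),
    one binary channel per colour in label_values (here VOC_COLORMAP)."""
    semantic_map = []
    for colour in label_values:
        # boolean map of the pixels that match this class colour exactly
        class_map = np.all(np.equal(label, colour), axis=-1)
        semantic_map.append(class_map)
    return np.stack(semantic_map, axis=-1)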

These are the augmentations, in case you want to see them:


def get_training_augmentation():
    train_transform = [
        album.Resize(height=RESIZE_HEIGHT, width=RESIZE_WIDTH), 
        album.OneOf(
            [
                album.HorizontalFlip(p=1),
                album.VerticalFlip(p=1),
                album.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
        album.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
                max_pixel_value=255.0,
            ),
    ]
    return album.Compose(train_transform)


def get_validation_augmentation():
    test_transform = [
        album.Resize(height=256, width=256),
    ]
    return album.Compose(test_transform)


def to_tensor(x, **kwargs):
    # HWC -> CHW, as PyTorch expects channels first
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Construct preprocessing transform
    Args:
        preprocessing_fn (callable): data normalization function
    Return:
        transform: albumentations.Compose
    """
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))

    return album.Compose(_transform)
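
To illustrate what get_preprocessing produces, here is a quick shape check (a sketch; preprocessing_fn is left out, so only the to_tensor layout change applies):

import numpy as np

pre = get_preprocessing()
sample = pre(image=np.zeros((256, 256, 3), dtype=np.uint8),
             mask=np.zeros((256, 256, 21), dtype='float'))
print(sample['image'].shape, sample['image'].dtype)  # (3, 256, 256) float32
print(sample['mask'].shape, sample['mask'].dtype)    # (21, 256, 256) float32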

This is my model, built with segmentation_models_pytorch (smp):

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = VOC_CLASSES
ACTIVATION = 'sigmoid'

# create segmentation model with pretrained encoder
model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
).to(DEVICE)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
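
As a sanity check, the model emits one sigmoid-activated channel per class (a quick probe; the 256x256 input size is illustrative):

model.eval()  # avoid BatchNorm running-stat updates during the probe
with torch.no_grad():
    out = model(torch.zeros(1, 3, 256, 256).to(DEVICE))
print(out.shape)  # torch.Size([1, 21, 256, 256]), each channel in [0, 1]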

These are the hyperparameters:

# Set num of epochs
EPOCHS = 10

# define loss function
loss = smp.utils.losses.DiceLoss()

# define metrics
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

# define optimizer
optimizer = torch.optim.Adam([
    dict(params=model.parameters(), lr=0.001),
])

# define learning rate scheduler 
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1, T_mult=2, eta_min=5e-4,
)
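
Stepping once per epoch, T_0=1 with T_mult=2 means the cosine restarts after cycles of 1, 2, then 4 epochs (i.e. at epochs 1, 3, 7). A throwaway copy shows the LR at each epoch start without touching the real optimizer (sketch):

probe_opt = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=0.001)
probe_sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    probe_opt, T_0=1, T_mult=2, eta_min=5e-4,
)
for epoch in range(EPOCHS):
    print(epoch, probe_opt.param_groups[0]['lr'])
    probe_sched.step()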

Trainers:

train_epoch = smp.utils.train.TrainEpoch(
    model,
    loss=loss,
    metrics=metrics,
    optimizer=optimizer,
    device=DEVICE,
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=DEVICE,
    verbose=True,
)
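
For completeness, the train_loader and valid_loader used below are built roughly like this (a sketch; the batch size is illustrative):

train_dataset = VOCDataset(
    mode='train',
    class_rgb_values=VOC_COLORMAP,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
valid_dataset = VOCDataset(
    mode='val',
    class_rgb_values=VOC_COLORMAP,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# num_workers stays at 0 because __getitem__ already moves tensors to DEVICE
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=8, shuffle=False)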

and the training part:

# running logs and best validation IoU (used below to keep the best checkpoint)
train_logs_list, valid_logs_list = [], []
best_iou_score = 0.0

for i in range(0, EPOCHS):

    # Perform training & validation
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    train_logs_list.append(train_logs)
    valid_logs_list.append(valid_logs)

    # Step the cosine warm-restart schedule once per epoch
    lr_scheduler.step()

    # Save model if a better val IoU score is obtained
    if best_iou_score < valid_logs['iou_score']:
        best_iou_score = valid_logs['iou_score']
        torch.save(model, 'drive/MyDrive/AI/voc2012_unet_v1.pth')
        print('Model saved!')

On each epoch the IoU increases and the Dice loss decreases, which is natural, but the model is overfitting on the background class.

This is the result (screenshot omitted).

If I add ignore_channels to DiceLoss(), the model doesn't predict background at all and the IoU is very low.
Can you give me any advice?

What I suspect is that the data augmentation is transforming the source images without applying the same transform to the corresponding masks/labels.
I would suggest training without the random data augmentation while recording the evolution of the loss value across consecutive iterations.
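
One quick way to check that is to push a single raw pair through the training augmentation and plot everything side by side (a sketch, assuming matplotlib; re-run it a few times, since the flips are random):

import cv2
import matplotlib.pyplot as plt

root = 'drive/MyDrive/AI/datasets/VOC/VOCdevkit/VOC2012/'
name = open(root + 'ImageSets/Segmentation/train.txt').readline().strip()

image = cv2.cvtColor(cv2.imread(root + 'JPEGImages/' + name + '.jpg'), cv2.COLOR_BGR2RGB)
mask = cv2.cvtColor(cv2.imread(root + 'SegmentationClass/' + name + '.png'), cv2.COLOR_BGR2RGB)

# albumentations receives both, so any spatial transform should hit both
sample = get_training_augmentation()(image=image, mask=mask)

fig, ax = plt.subplots(2, 2, figsize=(8, 8))
ax[0, 0].imshow(image)
ax[0, 0].set_title('image')
ax[0, 1].imshow(mask)
ax[0, 1].set_title('mask')
ax[1, 0].imshow(sample['image'])  # normalized floats, colours will be clipped
ax[1, 0].set_title('augmented image')
ax[1, 1].imshow(sample['mask'])
ax[1, 1].set_title('augmented mask')
plt.show()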