Using transforms on masks in multi-class semantic segmentation increases the number of classes

Hi,
I am trying to do semantic segmentation on the MIT ADE20K dataset in PyTorch. After applying transforms to the segmentation mask I found that the number of labels had increased.
Here is my custom Dataset.

import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as tf
from torch.utils import data
import random
from PIL import Image
import os

class CustomDataset(data.Dataset):

  def __init__(
      self, 
      mask_folder = '/content/data/ADEChallengeData2016/annotations',
      img_folder = '/content/data/ADEChallengeData2016/images',
      split = 'training',
      is_transforms = False,
      img_size = 512,
      test_mode = False,
      crop_ratio = 0.9
      ):
    
    super(CustomDataset, self).__init__()
    # img_folder and mask_folder each contain a 'training' and a 'validation' split
    self.img_folder = img_folder
    self.mask_folder = mask_folder
    self.is_transforms = is_transforms
    self.split = split
    self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
    self.crop_ratio = crop_ratio

    self.img_files = sorted(os.listdir(os.path.join(img_folder, split)))
    self.mask_files = sorted(os.listdir(os.path.join(mask_folder, split)))
    self.mean = np.array([104.00699, 116.66877, 122.67892])
  
  def __getitem__(self, index):

    img_name = self.img_files[index]
    img_path = os.path.join(self.img_folder, self.split, img_name)

    mask_name = self.mask_files[index]
    mask_path = os.path.join(self.mask_folder, self.split, mask_name)

    img = Image.open(img_path).convert("RGB")
    #img = np.array(img, dtype=np.uint8)

    mask = Image.open(mask_path)
    #mask = np.array(mask, dtype=np.int32)

    if self.is_transforms:
      img, mask = self.transform(img, mask)
    
    return img, mask
  
  def __len__(self):
    return len(self.img_files) 

  def transform(self, img, mask):
    #resize
    resize = transforms.Resize(size = self.img_size)
    img = resize(img)
    mask = resize(mask)

    # Random Crop
    i, j, h, w = transforms.RandomCrop.get_params(
        img, output_size=tuple(int(self.crop_ratio * x) for x in self.img_size))
    img = tf.crop(img, i, j, h, w)
    mask = tf.crop(mask, i, j, h, w)

    # Random horizontal flipping
    if random.random() > 0.5:
      img = tf.hflip(img)
      mask = tf.hflip(mask)
    # Random vertical Flipping
    if random.random() > 0.5:
      img = tf.vflip(img)
      mask = tf.vflip(mask)
    #Transform to tensor
    img = np.array(img)
    mask = np.array(mask)
    img = torch.from_numpy(img).float()
    mask = torch.from_numpy(mask).long()
    return img, mask
dst_adek = CustomDataset(is_transforms=True)
img, mask = dst_adek[0]
#labels in the mask
print(sorted(np.unique(np.array(mask))))

After executing the above code I get the following set of labels:

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 75, 77, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 95, 96, 97, 98, 100, 102, 103, 104, 105, 108, 109, 110, 113, 117, 120, 121, 123, 124, 126, 129, 130, 139, 150]

This is basically the whole range of labels from 0 to 150.

Whereas when is_transforms = False,

dst_adek = CustomDataset()
img, mask = dst_adek[0]
mask = np.array(mask)
#plt.imshow(label2rgb(mask))
print(np.unique(mask))

I get this set of labels:

[  0   1   4   5   6  13  18  32  33  43  44  88  97 105 126 139 150]

So is it not good to use transforms on the mask, and should I apply them only to the images?


I got the answer, and I am really sorry for posting the doubt without doing more work, which I should have done.
Initially I thought that converting the images to tensors was distorting my mask, but that is not the case, since I am using torch.from_numpy. It was the resizing: the default interpolation in transforms.Resize() is Image.BILINEAR, which blends neighbouring pixels and therefore changes the label values in the masks.
Changing the line

resize = transforms.Resize(size = self.img_size)

to

resize = transforms.Resize(size = self.img_size, interpolation=Image.NEAREST)

solves my problem.
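
To see what is going on, here is a minimal sketch (the 2x2 mask and its values are made up purely for illustration) that resizes a tiny mask with both interpolation modes and prints the unique values:

import numpy as np
from PIL import Image

# A tiny made-up mask with two well-separated labels, 0 and 150.
mask = Image.fromarray(np.array([[0, 150], [150, 0]], dtype=np.uint8))

# Bilinear resizing averages neighbouring pixels, so new values appear
# between 0 and 150 that do not correspond to any real class.
bilinear = mask.resize((8, 8), resample=Image.BILINEAR)
print(np.unique(np.array(bilinear)))

# Nearest-neighbour resizing only copies existing pixels, so the label
# set stays exactly {0, 150}.
nearest = mask.resize((8, 8), resample=Image.NEAREST)
print(np.unique(np.array(nearest)))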

Again, I am sorry for this silly doubt, but I think it might be helpful for someone else in the community.


Mask/Faster R-CNN use torch.nn.functional.interpolate under the hood, with the bilinear method for the images and nearest for the masks. You could try that.
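
If you are already working with tensors instead of PIL images, the same idea looks roughly like this (just a sketch; the tensor names and shapes are assumptions, not code taken from torchvision):

import torch
import torch.nn.functional as F

# Assumed shapes: img is (C, H, W) float, mask is (H, W) long with class indices.
img = torch.rand(3, 300, 400)
mask = torch.randint(0, 151, (300, 400))

size = (512, 512)

# Bilinear for the image; interpolate expects a batch dimension, so add one
# before the call and remove it afterwards.
img_resized = F.interpolate(img.unsqueeze(0), size=size,
                            mode='bilinear', align_corners=False).squeeze(0)

# Nearest for the mask so no new label values are created; interpolate wants
# a float input, so cast back to long afterwards.
mask_resized = F.interpolate(mask[None, None].float(), size=size,
                             mode='nearest')[0, 0].long()

print(img_resized.shape)            # torch.Size([3, 512, 512])
print(torch.unique(mask_resized))   # only values that were already in mask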