Hi,
I am trying to do Semantic Segmentation on the MIT ADE20K dataset in PyTorch. After using transforms on the segmentation mask I found that the number of labels has been increased.
Here is my Custom Dataset.
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as tf
from torch.utils import data
import random
from PIL import Image
import os
class CustomDataset(data.Dataset):
def __init__(
self,
mask_folder = '/content/data/ADEChallengeData2016/annotations',
img_folder = '/content/data/ADEChallengeData2016/images',
split = 'training',
is_transforms = False,
img_size = 512,
test_mode = False,
crop_ratio = 0.9
):
super(CustomDataset, self).__init__()
# root contains the
self.img_folder = img_folder
self.mask_folder = mask_folder
self.is_transforms = is_transforms
self.split = split
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.crop_ratio = crop_ratio
self.img_files = sorted(os.listdir(os.path.join(img_folder, split)))
self.mask_files = sorted(os.listdir(os.path.join(mask_folder, split)))
self.mean = self.mean = np.array([104.00699, 116.66877, 122.67892])
def __getitem__(self, index):
img_name = self.img_files[index]
img_path = os.path.join(self.img_folder, self.split, img_name)
mask_name = self.mask_files[index]
mask_path = os.path.join(self.mask_folder, self.split, mask_name)
img = Image.open(img_path).convert("RGB")
#img = np.array(img, dtype=np.uint8)
mask = Image.open(mask_path)
#mask = np.array(mask, dtype=np.int32)
if self.is_transforms:
img, mask = self.transform(img, mask)
return img, mask
def __len__(self):
return len(self.img_files)
def transform(self, img, mask):
#resize
resize = transforms.Resize(size = self.img_size)
img = resize(img)
mask = resize(mask)
# Random Crop
i, j, h, w = transforms.RandomCrop.get_params(img, output_size =
tuple([int((self.crop_ratio*x)) for x in self.img_size]))
img = tf.crop(img, i, j, h, w)
mask = tf.crop(mask, i, j, h, w)
# Random horizontal flipping
if random.random() > 0.5:
img = tf.hflip(img)
mask = tf.hflip(mask)
# Random vertical Flipping
if random.random() > 0.5:
img = tf.vflip(img)
mask = tf.vflip(mask)
#Transform to tensor
img = np.array(img)
mask = np.array(mask)
img = torch.from_numpy(img).float()
mask = torch.from_numpy(mask).long()
return img, mask
dst_adek = CustomDataset(is_transforms=True)
img, mask = dst_adek[0]
#labels in the mask
print(sorted(np.unique(np.array(mask))))
After executing above code I get following sets of labels:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 50, 51, 52, 53, 55, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 71, 72, 73, 74, 75, 77, 80, 81, 82, 83, 85, 86, 87, 88, 89, 90, 95, 96, 97, 98, 100, 102, 103, 104, 105, 108, 109, 110, 113, 117, 120, 121, 123, 124, 126, 129, 130, 139, 150]
This is basically all pixels from 0 to 150.
While when is_transforms = False
,
dst_adek = CustomDataset()
img, mask = dst_adek[0]
mask = np.array(mask)
#plt.imshow(label2rgb(mask))
print(np.unique(mask))
I get these sets of labels:
[ 0 1 4 5 6 13 18 32 33 43 44 88 97 105 126 139 150]
So It is not good to use transforms on the mask and only use that on the images?