Semantic segmentation: How to map RGB mask in data loader

Hi,

I have a question about loading and mapping an RGB mask image for semantic segmentation (using U-Net). The mask data consists of RGB images with the same resolution as the original RGB images. There are 4 classes. In RGB color space, class 1 is red (255,0,0), class 2 is green (0,255,0), class 3 is blue (0,0,255) and class 4, the background, is black (0,0,0). I want to map the RGB values of the mask to the class values in the data loader.

Example - using the indexes 1, 2, 3, and 4 (not starting with 0):
A ‘pixel’ with value (0,255,0), which is class 2, gets the new value of (0, 2, 0)

I’ve read this example here and could successfully reproduce the mapping for gray-scale mask images.

How is it done with RGB-values? I’m new to Python, coming from C++, and the syntax for self.mapping and the for-loop gives me a headache. I’ve read this and this post, but couldn’t get them to work for my case. Any help is appreciated!

Here is my working code for the data loader using grayscale images. For transformations I use Albumentations, but only horizontal flip, no ToTensor or normalization:

import numpy as np
import torch
import torch.utils.data
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as tvtransforms
import torch.nn.functional


class MyDataset(Dataset):
    """Segmentation dataset that maps grey-level masks to class indices.

    Each item is ``(image, mask)``:
    ``image`` is a float tensor of shape (3, H, W), scaled to [0, 1] and then
    normalised with mean/std 0.5 into [-1, 1];
    ``mask`` is an int64 tensor of shape (H, W) with one class index per pixel.

    Args:
        image_root: Sequence of image file paths, indexed in parallel with
            ``mask_root``.
        mask_root: Sequence of grey-level mask file paths.
        transforms: Optional Albumentations-style callable invoked as
            ``transforms(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` numpy arrays.
    """

    def __init__(self, image_root, mask_root, transforms=None):
        self.image_root = image_root
        self.mask_root = mask_root
        self.transforms = transforms
        # Kept for backward compatibility; not read by this class.
        self.img = None
        self.msk = None
        # grey level in the mask file -> class index
        self.mapping = {
            0: 0,
            80: 1,
            160: 2,
            240: 3,
        }

    def __len__(self):
        return len(self.image_root)

    def mask_to_class_gray(self, mask):
        """Map a grey-level mask to an int64 class-index tensor.

        Values listed in ``self.mapping`` are replaced by their class index.
        Values not listed are kept unchanged.

        Args:
            mask: A PIL image or array of shape (H, W). Singleton channel or
                batch dims such as (H, W, 1) or (1, H, W) are also accepted.

        Returns:
            An int64 tensor of shape (H, W).
        """
        src = torch.from_numpy(np.array(mask))
        # Remove singleton channel/batch dims only. A bare squeeze() would
        # also collapse H or W when one of them happens to be 1.
        while src.ndim > 2 and src.shape[-1] == 1:
            src = src[..., 0]
        while src.ndim > 2 and src.shape[0] == 1:
            src = src[0]

        # Write into a separate tensor and always test against the untouched
        # source. Remapping in place would cascade when a class index equals
        # a later grey key (e.g. {1: 2, 2: 3} would turn 1 into 3). int64 also
        # avoids uint8 wrap-around for targets such as -100.
        out = src.to(torch.long, copy=True)
        for grey, class_idx in self.mapping.items():
            out[src == grey] = class_idx
        return out

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index`` (see class docstring)."""
        # Decode inside ``with`` so file handles are released promptly.
        # convert("RGB") guarantees the 3 channels that Normalize expects.
        with Image.open(self.image_root[index]) as img:
            image_np = np.array(img.convert("RGB"))
        with Image.open(self.mask_root[index]) as msk:
            mask_np = np.array(msk)

        # apply optional data augmentation
        if self.transforms is not None:
            augmented = self.transforms(image=image_np, mask=mask_np)
            image_np = augmented['image']
            mask_np = augmented['mask']

        # uint8 HWC in [0, 255] -> float CHW in [0, 1]. Normalize(0.5, 0.5)
        # computes (x - 0.5) / 0.5, which gives [-1, 1] only for inputs in
        # [0, 1]. Without the division the range would be [-1, 509].
        img_new = torch.from_numpy(np.ascontiguousarray(image_np)).float().div_(255.0)
        img_new = img_new.permute(2, 0, 1).contiguous()
        norm = tvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        img_new = norm(img_new)

        # map gray mask to class (already int64)
        mask_new = self.mask_to_class_gray(mask_np)
        return img_new, mask_new

After reading this post and this post I was finally able to come up with a working solution.

Here is the code for my Dataloader:

import numpy as np
import torch
import torch.utils.data
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as tvtransforms
import torch.nn.functional


class MyDataset(Dataset):
    """Segmentation dataset that turns colour-coded RGB masks into class-index maps.

    Each item is ``(image, mask)``:
    ``image`` is a float tensor of shape (3, H, W), scaled to [0, 1] and then
    normalised with mean/std 0.5 into [-1, 1];
    ``mask`` is an int64 tensor of shape (H, W) with one class index per pixel.
    After batching by the DataLoader the mask is (N, H, W), which is the
    target layout ``nn.CrossEntropyLoss`` uses for per-pixel classification.

    Args:
        image_root: Sequence of image file paths, indexed in parallel with
            ``mask_root``.
        mask_root: Sequence of colour-mask file paths.
        transforms: Optional Albumentations-style callable invoked as
            ``transforms(image=..., mask=...)`` that returns a dict with
            ``'image'`` and ``'mask'`` numpy arrays.
        unmatched_value: Class index written for mask pixels whose colour is
            in no ``self.mapping`` entry (e.g. anti-aliased edges). Pass
            ``-100`` to have ``nn.CrossEntropyLoss`` ignore them by default.
    """

    def __init__(self, image_root, mask_root, transforms=None, unmatched_value=0):
        self.image_root = image_root
        self.mask_root = mask_root
        self.transforms = transforms
        self.unmatched_value = unmatched_value
        # Kept for backward compatibility; not read by this class.
        self.img = None
        self.msk = None
        # RGB colour -> class index (PIL decodes masks in RGB order, not BGR)
        self.mapping = {(0, 0, 0): 0,    # 0 = background (black)
                        (255, 0, 0): 1,  # 1 = class 1 (red)
                        (0, 255, 0): 2,  # 2 = class 2 (green)
                        (0, 0, 255): 3}  # 3 = class 3 (blue)

    def __len__(self):
        return len(self.image_root)

    def mask_to_class_rgb(self, mask):
        """Map an (H, W, C>=3) colour mask to an (H, W) int64 class-index tensor.

        Only the first three channels are compared, so an alpha channel is
        ignored. Pixels matching no ``self.mapping`` colour receive
        ``self.unmatched_value``.

        Raises:
            ValueError: if the mask is not (H, W, C) with at least 3 channels
                after leading singleton dims are removed.
        """
        mask = torch.from_numpy(np.array(mask))
        # Remove leading singleton dims such as (1, H, W, C). Unlike a bare
        # squeeze() this never collapses H or W when one of them is 1.
        while mask.ndim > 3 and mask.shape[0] == 1:
            mask = mask[0]
        if mask.ndim != 3 or mask.shape[-1] < 3:
            raise ValueError(
                f"expected an (H, W, C>=3) colour mask, got shape {tuple(mask.shape)}")
        rgb = mask[..., :3]

        # Start from a defined value. torch.empty() would leave pixels whose
        # colour is not in the mapping holding arbitrary memory contents.
        mask_out = torch.full(rgb.shape[:2], self.unmatched_value, dtype=torch.long)
        for colour, class_idx in self.mapping.items():
            hit = (rgb == torch.tensor(colour, dtype=rgb.dtype)).all(dim=-1)
            mask_out[hit] = class_idx
        return mask_out

    def __getitem__(self, index):
        """Return ``(image, mask)`` for sample ``index`` (see class docstring)."""
        # Decode inside ``with`` so file handles are released promptly.
        # convert("RGB") drops an alpha channel and expands palette ("P")
        # images, so both arrays are (H, W, 3) uint8.
        with Image.open(self.image_root[index]) as img:
            image_np = np.array(img.convert("RGB"))
        with Image.open(self.mask_root[index]) as msk:
            mask_np = np.array(msk.convert("RGB"))

        # apply optional data augmentation
        if self.transforms is not None:
            augmented = self.transforms(image=image_np, mask=mask_np)
            image_np = augmented['image']
            mask_np = augmented['mask']

        # uint8 HWC in [0, 255] -> float CHW in [0, 1]. Normalize(0.5, 0.5)
        # computes (x - 0.5) / 0.5, which gives [-1, 1] only for inputs in
        # [0, 1]. Without the division the range would be [-1, 509].
        img_new = torch.from_numpy(np.ascontiguousarray(image_np)).float().div_(255.0)
        img_new = img_new.permute(2, 0, 1).contiguous()
        norm = tvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        img_new = norm(img_new)

        # map colour mask to class indices (already int64)
        mask_new = self.mask_to_class_rgb(mask_np)
        return img_new, mask_new

With this my mask output has the shape [Batch_size, height, width]. Is that fine?

Why do you use torch.squeeze() here? I thought your mask had 3 channels.

Without the working code at hand, here is a quick reply: I think I squeezed the mask there because its shape was something like [Batch_size, height, width, 1]. Try

print(mask.shape)

to see the exact shape. Someone wiser than me might explain why and how the mask got a fourth dimension of size 1.

As a remark, the dataloader works as it should in my project.