Semantic segmentation: How to map RGB mask in data loader

After reading this post and this post, I was finally able to come up with a working solution.

Here is the code for my Dataset (which I later wrap in a DataLoader):

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as tvtransforms


class MyDataset(Dataset):
    def __init__(self, image_root, mask_root, transforms=None):
        self.image_root = image_root
        self.mask_root = mask_root
        self.transforms = transforms
        # mapping to convert RGB mask colors to class indices
        self.mapping = {(0, 0, 0): 0,    # 0 = background
                        (255, 0, 0): 1,  # 1 = class 1 (red)
                        (0, 255, 0): 2,  # 2 = class 2 (green)
                        (0, 0, 255): 3}  # 3 = class 3 (blue)

    def __len__(self):
        return len(self.image_root)

    def mask_to_class_rgb(self, mask):
        print('----mask->rgb----')
        mask = torch.from_numpy(np.array(mask))
        mask = torch.squeeze(mask)  # remove singleton dimensions, if any

        # check the present values in the mask, 0 and 255 in my case
        print('unique values rgb    ', torch.unique(mask)) 
        # -> unique values rgb     tensor([  0, 255], dtype=torch.uint8)

        class_mask = mask.permute(2, 0, 1).contiguous()  # HWC -> CHW
        h, w = class_mask.shape[1], class_mask.shape[2]
        # default to background (0) so colors missing from the mapping
        # do not stay uninitialized
        mask_out = torch.zeros(h, w, dtype=torch.long)

        for color, class_index in self.mapping.items():
            # per-channel equality with the current color, shape [3, h, w]
            idx = (class_mask == torch.tensor(color, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            # a pixel matches only if all three channels agree
            validx = (idx.sum(0) == 3)
            mask_out[validx] = torch.tensor(class_index, dtype=torch.long)

        # check the present values after mapping, in my case 0, 1, 2, 3
        print('unique values mapped ', torch.unique(mask_out))
        # -> unique values mapped  tensor([0, 1, 2, 3])
       
        return mask_out
    
    def __getitem__(self, index):
        # load images and masks
        img = Image.open(self.image_root[index])
        msk = Image.open(self.mask_root[index])

        img_new = img
        mask_new = msk

        # apply data augmentation and tensor transformation
        if self.transforms is not None:
            # convert pil image to numpy array
            image_np = np.array(img)
            # print(image_np)
            mask_np = np.array(msk)           
            # apply augmentations (only a random flip with Albumentations in my case)
            augmented = self.transforms(image=image_np, mask=mask_np)
            img_new = augmented['image']
            mask_new = augmented['mask']
            img_new = torch.from_numpy(img_new).float()
        else:
            # convert the PIL image to a numpy array first; torch.from_numpy() does not accept PIL images
            img_new = torch.from_numpy(np.array(img_new)).float()

        # normalize image
        img_new = img_new.permute(2, 0, 1).contiguous()  # HWC -> CHW
        norm = tvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        img_new = norm(img_new)

        # map gray mask to class
        # mask_new = self.mask_to_class_gray(mask_new)
        mask_new = self.mask_to_class_rgb(mask_new)
        mask_new = mask_new.long()        

        return img_new, mask_new
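
For reference, this is roughly how I wrap the dataset in a DataLoader and check the shapes. The paths and batch size below are just placeholders for my real setup:

from torch.utils.data import DataLoader

image_paths = ['images/0001.png', 'images/0002.png']  # placeholder list of image paths
mask_paths = ['masks/0001.png', 'masks/0002.png']     # one RGB mask per image

dataset = MyDataset(image_paths, mask_paths, transforms=None)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

images, masks = next(iter(loader))
print(images.shape)  # torch.Size([2, 3, H, W])
print(masks.shape)   # torch.Size([2, H, W])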

With this, my mask output has the shape [batch_size, height, width]. Is that fine?
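
As far as I understand, this is the target shape nn.CrossEntropyLoss expects when the model output has shape [batch_size, num_classes, height, width], so I did a quick sanity check with dummy tensors (the sizes below are made up):

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(2, 4, 240, 320)         # [batch, num_classes, h, w]
target = torch.randint(0, 4, (2, 240, 320))  # [batch, h, w], dtype int64
loss = criterion(logits, target)
print(loss)  # scalar loss value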