After reading this post and this post, I was finally able to come up with a working solution.
Here is the code for my custom Dataset (which I feed into a DataLoader):
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as tvtransforms
class MyDataset(Dataset):
    def __init__(self, image_root, mask_root, transforms=None):
        # image_root and mask_root are lists of file paths
        self.image_root = image_root
        self.mask_root = mask_root
        self.transforms = transforms
        # mapping instructions to map the RGB mask colors to class indices
        self.mapping = {(0, 0, 0): 0,    # 0 = background
                        (255, 0, 0): 1,  # 1 = class 1
                        (0, 255, 0): 2,  # 2 = class 2
                        (0, 0, 255): 3}  # 3 = class 3

    def __len__(self):
        return len(self.image_root)
    def mask_to_class_rgb(self, mask):
        print('----mask->rgb----')
        mask = torch.from_numpy(np.array(mask))
        mask = torch.squeeze(mask)  # remove singleton dimensions, if any
        # check the values present in the mask, 0 and 255 in my case
        print('unique values rgb ', torch.unique(mask))
        # -> unique values rgb tensor([0, 255], dtype=torch.uint8)
        class_mask = mask.permute(2, 0, 1).contiguous()  # HWC -> CHW
        h, w = class_mask.shape[1], class_mask.shape[2]
        # initialize with zeros so unmapped pixels default to background
        mask_out = torch.zeros(h, w, dtype=torch.long)
        for k in self.mapping:
            # pixels where all three channels match the color key
            idx = (class_mask == torch.tensor(k, dtype=torch.uint8).unsqueeze(1).unsqueeze(2))
            validx = (idx.sum(0) == 3)
            mask_out[validx] = torch.tensor(self.mapping[k], dtype=torch.long)
        # check the values present after mapping, in my case 0, 1, 2, 3
        print('unique values mapped ', torch.unique(mask_out))
        # -> unique values mapped tensor([0, 1, 2, 3])
        return mask_out
    def __getitem__(self, index):
        # load image and mask
        img = Image.open(self.image_root[index])
        msk = Image.open(self.mask_root[index])
        # apply data augmentation and tensor transformation
        if self.transforms is not None:
            # convert PIL images to numpy arrays
            image_np = np.array(img)
            mask_np = np.array(msk)
            # apply augmentations -> only random flip with Albumentations
            augmented = self.transforms(image=image_np, mask=mask_np)
            img_new = augmented['image']
            mask_new = augmented['mask']
            img_new = torch.from_numpy(img_new).float()
        else:
            # torch.from_numpy needs a numpy array, not a PIL image
            img_new = torch.from_numpy(np.array(img)).float()
            mask_new = msk
        # normalize image
        img_new = img_new.permute(2, 0, 1).contiguous()  # HWC -> CHW
        # note: Normalize computes (x - mean) / std per channel; if the image
        # is still in [0, 255] here, divide by 255.0 first to land in [-1, 1]
        norm = tvtransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        img_new = norm(img_new)
        # map RGB mask to class indices
        # mask_new = self.mask_to_class_gray(mask_new)
        mask_new = self.mask_to_class_rgb(mask_new)
        mask_new = mask_new.long()
        return img_new, mask_new
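For reference, this is roughly how I build the loader and check the shapes. The file paths and the flip transform below are placeholders for my actual setup, not part of the solution itself:

import glob
import albumentations as A
from torch.utils.data import DataLoader

# hypothetical paths and transform, just to illustrate the usage
image_paths = sorted(glob.glob('data/images/*.png'))
mask_paths = sorted(glob.glob('data/masks/*.png'))
transforms = A.Compose([A.HorizontalFlip(p=0.5)])

dataset = MyDataset(image_paths, mask_paths, transforms=transforms)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

images, masks = next(iter(loader))
print(images.shape)  # torch.Size([4, 3, H, W])
print(masks.shape)   # torch.Size([4, H, W])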
With this, my mask output has the shape [batch_size, height, width]. Is that fine?
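To double-check on my side, I ran dummy tensors of these shapes through nn.CrossEntropyLoss (the sizes are made up) and it runs without errors:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 4, 240, 320)          # [batch, num_classes, H, W] from the model
targets = torch.randint(0, 4, (4, 240, 320))  # [batch, H, W] class-index mask
loss = criterion(logits, targets)
print(loss.item())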