[Solved] Loading Images from Masks

Hello,

I have a dataset with complete scene images in one subdirectory and binary masks of each category in a “masks” directory. There is a single subdirectory for each category. Is there a loader for this type of dataset?

1 Like

This is what worked for me:

import os
import random
from PIL import Image
import numpy as np
import torch.utils.data as utils_data

class MyDataSet(utils_data.Dataset):
    """Pairs scene images with the binary masks of one category.

    The masks are read from ``<root_dir>/<mask_dir>/<label>/`` and each mask
    is matched to the image ``<root_dir>/<image_dir>/<mask stem>.jpg``.

    Args:
        root_dir: Dataset root containing the image and mask directories.
        image_dir: Sub-directory of ``root_dir`` holding the ``.jpg`` images.
        mask_dir: Sub-directory of ``root_dir`` holding one folder per label.
        label: Name of the category folder inside ``mask_dir`` to load.
        transform: Optional callable applied to the sample dict returned by
            ``__getitem__``; its return value is handed back to the caller.
    """

    def __init__(self, root_dir, image_dir, mask_dir, label, transform=None):
        self.dataset_path = root_dir
        self.image_dir = image_dir
        self.mask_dir = os.path.join(mask_dir, label)
        self.transform = transform

        mask_full_path = os.path.join(self.dataset_path, self.mask_dir)
        # Keep regular files only (sub-directories are skipped).  Sorting
        # before shuffling makes the order reproducible under a fixed
        # ``random.seed`` because ``os.listdir`` returns entries in
        # arbitrary order.
        self.mask_file_list = sorted(
            entry
            for entry in os.listdir(mask_full_path)
            if os.path.isfile(os.path.join(mask_full_path, entry))
        )
        random.shuffle(self.mask_file_list)

    def __len__(self):
        """Return the number of mask files, i.e. the number of samples."""
        return len(self.mask_file_list)

    def __getitem__(self, index):
        """Load the image/mask pair at ``index``.

        Returns:
            A dict with ``'image'`` (float32 array, channels first) and
            ``'labels'`` (uint8 array built from the mask), or whatever
            ``transform`` returns for that dict when a transform is set.
        """
        mask_file = self.mask_file_list[index]
        stem = os.path.splitext(mask_file)[0]
        img_name = os.path.join(self.dataset_path, self.image_dir, stem + '.jpg')
        mask_name = os.path.join(self.dataset_path, self.mask_dir, mask_file)

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into NumPy arrays.
        with Image.open(img_name) as img:
            # Force three channels so greyscale files do not break the
            # channel-first transpose below.
            pixels = np.array(img.convert('RGB'))
        with Image.open(mask_name) as mask:
            labels = np.array(mask).astype(np.uint8)

        # (height, width, channel) -> (channel, height, width)
        image = np.ascontiguousarray(
            np.transpose(pixels, (2, 0, 1)), dtype=np.float32
        )
        sample = {'image': image, 'labels': labels}

        if self.transform:
            sample = self.transform(sample)

        return sample

You can use it as follows:

# Per-channel normalisation constants (presumably the usual ImageNet
# statistics -- confirm they suit your data).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# NOTE(review): this pipeline is built but never handed to MyDataSet below,
# so it currently has no effect.  MyDataSet passes its transform a dict
# sample; confirm these transforms accept that before wiring it in.
# transforms.RandomSizedCrop(224) could be added as the first step.
transform_pipeline = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# One dataset instance covers a single mask category (args.label).
train_data = MyDataSet(
    root_dir=args.traindir,
    image_dir=args.image_dir,
    mask_dir=args.mask_dir,
    label=args.label,
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)

I found these resources helpful:

3 Likes