Hello,
I have a dataset with the full scene images in one subdirectory and per-category binary masks in a `masks` directory, which contains one subdirectory per category. Is there an existing loader for this kind of dataset?
> Hello,
> I have a dataset with complete scene images in one subdirectory and binary masks of each category in a “masks” directory. There is a single subdirectory for each category. Is there a loader for this type of dataset?
This is what worked for me:
import os
import random
from PIL import Image
import numpy as np
import torch.utils.data as utils_data
class MyDataSet(utils_data.Dataset):
    """Segmentation dataset pairing scene images with per-category binary masks.

    Expected layout, relative to ``root_dir``::

        <image_dir>/<name><image_ext>
        <mask_dir>/<label>/<name>.<ext>

    Every regular file in ``<mask_dir>/<label>`` defines one sample. The
    matching image is found by stripping the mask file's last extension and
    appending ``image_ext``.

    Each sample is a dict ``{'image': ..., 'labels': ...}``:
    ``'image'`` is a float32 (C, H, W) array of the raw pixel values, and
    ``'labels'`` is the mask as a uint8 array.

    Args:
        root_dir: Dataset root directory.
        image_dir: Image subdirectory, relative to ``root_dir``.
        mask_dir: Mask subdirectory, relative to ``root_dir``.
        label: Category subdirectory inside ``mask_dir`` to load.
        transform: Optional callable applied to the sample dict.
        image_ext: Extension of the image files (default ``'.jpg'``).
        shuffle: Shuffle the file list once at construction. Defaults to
            ``True``, as before. When ``False`` the list stays sorted, so
            the order is reproducible.
    """

    def __init__(self, root_dir, image_dir, mask_dir, label, transform=None,
                 image_ext='.jpg', shuffle=True):
        self.dataset_path = root_dir
        self.image_dir = image_dir
        self.mask_dir = os.path.join(mask_dir, label)
        self.transform = transform
        self.image_ext = image_ext
        mask_full_path = os.path.join(self.dataset_path, self.mask_dir)
        # os.listdir order is arbitrary; sort so shuffle=False is reproducible.
        self.mask_file_list = sorted(
            f for f in os.listdir(mask_full_path)
            if os.path.isfile(os.path.join(mask_full_path, f))
        )
        if shuffle:
            random.shuffle(self.mask_file_list)

    def __len__(self):
        """Return the number of mask files (one sample per mask)."""
        # DataLoader samplers (e.g. shuffle=True -> RandomSampler) call len().
        return len(self.mask_file_list)

    def __getitem__(self, index):
        """Load the image/mask pair at ``index`` and apply ``transform``."""
        mask_file = self.mask_file_list[index]
        file_name = mask_file.rsplit('.', 1)[0]
        img_name = os.path.join(self.dataset_path, self.image_dir,
                                file_name + self.image_ext)
        mask_name = os.path.join(self.dataset_path, self.mask_dir, mask_file)
        # np.array() forces PIL's lazy load; the with-blocks then close the
        # underlying file handles instead of leaking one per sample.
        with Image.open(img_name) as img:
            image = np.array(img)
        with Image.open(mask_name) as msk:
            labels = np.array(msk).astype(np.uint8)
        if image.ndim == 2:
            # Grayscale image has no channel axis; add one so the result is (1, H, W).
            image = image[np.newaxis]
        else:
            # (H, W, C) -> (C, H, W)
            image = np.moveaxis(image, -1, 0)
        image = image.astype(np.float32)
        sample = {'image': image, 'labels': labels}
        if self.transform:
            sample = self.transform(sample)
        return sample
You can use it as follows:
# ImageNet per-channel statistics (same values as the original Normalize call),
# shaped (C, 1, 1) to broadcast over a (C, H, W) image.
_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(-1, 1, 1)
_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(-1, 1, 1)


def joint_transform(sample):
    """Randomly flip image and mask together, then normalize the image.

    torchvision's RandomHorizontalFlip / ToTensor / Normalize act on a single
    image, not on the ``{'image': ..., 'labels': ...}`` dict MyDataSet
    returns. Flipping only the image would also misalign it with its mask.

    Args:
        sample: Dict with ``'image'`` as a (C, H, W) array of raw 0-255
            pixel values and ``'labels'`` as the mask array (H, W[, ...]).

    Returns:
        A new dict containing the flipped, normalized float32 image and the
        (possibly flipped) mask. Both arrays are made C-contiguous because
        negative-stride views cannot be converted to tensors when batched.
    """
    image, labels = sample['image'], sample['labels']
    if random.random() < 0.5:
        image = image[:, :, ::-1]  # flip the width axis of (C, H, W)
        labels = labels[:, ::-1]   # flip the width axis of (H, W[, ...])
    # Mirrors ToTensor's /255 scaling, then Normalize.
    # NOTE(review): assumes a 3-channel image to match the 3-entry mean/std.
    image = (image / 255.0 - _MEAN) / _STD
    return {
        'image': np.ascontiguousarray(image, dtype=np.float32),
        'labels': np.ascontiguousarray(labels),
    }


# `args` is assumed to be the training script's argparse namespace
# (traindir, image_dir, mask_dir, label, batch_size, workers).
train_data = MyDataSet(args.traindir, args.image_dir, args.mask_dir, args.label,
                       transform=joint_transform)
# The imports above bind only the `utils_data` alias, not `torch`.
train_loader = utils_data.DataLoader(
    train_data,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
)
I found these resources helpful: