dataset = tf.data.Dataset.from_tensor_slices((imagepath, labelpath))
dataset = dataset.map(
    lambda image_path, label_path: tuple(tf.py_func(
        self.input_parser, [image_path, label_path], [tf.float32, tf.int32])),
    num_parallel_calls=2)
Are you working on semantic segmentation?
Yes man! Can you kindly help me convert them?
Assuming you have train_images and train_labels (segmentation masks) folders, in PyTorch you need to prepare a Dataset class and a DataLoader that feeds the data into your model for training. Here’s an example of a very simple dataset class and dataloader:
# The Dataset class definition:
import os

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    def __init__(self, images_dir, labels_dir):
        # Sort both listings so that images[i] and labels[i] refer to the same sample
        # (this assumes image and mask file names sort in the same order)
        self.images = sorted(os.listdir(images_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        # ToTensor converts a PIL image to a float tensor of shape (C, H, W) in [0, 1]
        self.transform = T.Compose([
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img_path = os.path.join(self.images_dir, self.images[i])
        msk_path = os.path.join(self.labels_dir, self.labels[i])
        img = Image.open(img_path).convert('RGB')
        lbl = Image.open(msk_path).convert('L')
        img = img.resize((512, 512))
        # Nearest-neighbour resizing avoids interpolating between class values in the mask
        lbl = lbl.resize((512, 512), Image.NEAREST)
        image = self.transform(img)
        label = torch.from_numpy(np.array(lbl)).long()
        return image, label
The DataLoader:
train_dataset = SegmentationDataset('/path/to/train/images/', '/path/to/train/labels/')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)
# You could write similar code for the validation set
Model training:
for epoch in range(50):
    for idx, data in enumerate(train_dataloader):
        images, labels = data[0], data[1]
        # Rest of the training code
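For the part marked "Rest of the training code", here is a minimal sketch of what a single training step could look like. The one-layer `model`, `num_classes = 3`, the SGD optimizer and the learning rate are placeholders I made up for illustration, not something from the code above. The point is only that the `(N, H, W)` long labels returned by the dataset can go straight into `nn.CrossEntropyLoss` together with `(N, C, H, W)` logits.

```python
import torch
import torch.nn as nn

num_classes = 3  # assumption: mask pixels hold class indices 0..num_classes-1

# Hypothetical stand-in for a real segmentation network:
# maps (N, 3, H, W) images to (N, num_classes, H, W) logits
model = nn.Conv2d(3, num_classes, kernel_size=3, padding=1)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train_step(images, labels):
    """Run one optimization step and return the loss as a Python float."""
    optimizer.zero_grad()
    logits = model(images)            # (N, num_classes, H, W)
    loss = criterion(logits, labels)  # labels: (N, H, W), dtype long
    loss.backward()
    optimizer.step()
    return loss.item()


# Random tensors with the same layout and dtypes the dataset above produces
images = torch.rand(2, 3, 64, 64)
labels = torch.randint(0, num_classes, (2, 64, 64))
loss_value = train_step(images, labels)
```

Inside the loop you would call `train_step(images, labels)` on each batch. If your mask files store values such as 0 and 255 rather than class indices, you would need to remap them first, since `CrossEntropyLoss` expects targets in the range 0 to `num_classes - 1`.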
Any suggestions for a 3D dataset? I guess PIL will not work.
What is the format of your dataset? I think you should be able to make the necessary changes in the dataset class. Instead of loading the images/labels from the directory, load the 3D dataset (np.array() or .nii) and then crop the number of slices depending on your memory.
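As a rough sketch of that idea, assuming each volume and its mask are saved as `.npy` arrays of shape `(D, H, W)` in two folders: the class name `VolumeDataset` and the `num_slices` parameter are made up here for illustration. For `.nii` files you would swap `np.load` for a NIfTI reader, for example `nibabel.load(path).get_fdata()` if you have nibabel installed.

```python
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class VolumeDataset(Dataset):
    """Hypothetical dataset for 3D volumes stored as .npy arrays of shape (D, H, W)."""

    def __init__(self, volumes_dir, labels_dir, num_slices=32):
        # Sorted listings so that volumes[i] and labels[i] belong together
        self.volumes = sorted(os.listdir(volumes_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.volumes_dir = volumes_dir
        self.labels_dir = labels_dir
        self.num_slices = num_slices  # depth of the crop; lower it if memory is tight

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, i):
        vol = np.load(os.path.join(self.volumes_dir, self.volumes[i]))
        lbl = np.load(os.path.join(self.labels_dir, self.labels[i]))
        # Pick a random window of consecutive slices along the depth axis
        depth = vol.shape[0]
        n = min(self.num_slices, depth)
        start = np.random.randint(0, depth - n + 1)
        vol = vol[start:start + n]
        lbl = lbl[start:start + n]
        # Volume gets a channel dimension (1, D, H, W); the mask stays (D, H, W)
        volume = torch.from_numpy(vol.astype(np.float32)).unsqueeze(0)
        label = torch.from_numpy(lbl.astype(np.int64))
        return volume, label
```

The same window is applied to the volume and its mask so they stay aligned. Note that volumes with fewer than `num_slices` slices come back shorter, so batching them with the default collate function would need padding or a batch size of 1.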