Hello. What are the PyTorch equivalents of the TensorFlow functions below? Any help would be appreciated. Both imagepath and labelpath are lists of folder contents (images and labels).

dataset = tf.data.Dataset.from_tensor_slices((imagepath, labelpath))

# Parse each (image path, label path) pair with input_parser
dataset = dataset.map(lambda image_path, label_path: tuple(tf.py_func(
    self.input_parser, [image_path, label_path], [tf.float32, tf.int32])), num_parallel_calls=2)

@ptrblck can you please help?

Are you working on semantic segmentation?

Yes man! Can you kindly help me convert them?

Assuming you have train_images and train_labels (segmentation masks) folders, in PyTorch you need a Dataset class that loads one image/mask pair at a time and a DataLoader that batches those pairs and feeds them to your model for training. Here’s an example of a very simple Dataset class and DataLoader:

# The Dataset class definition:

import torch
import torchvision.transforms as T
import numpy as np
import os

from PIL import Image
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):

    def __init__(self, images_dir, labels_dir):
        self.images = sorted(os.listdir(images_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.images_dir = images_dir
        self.labels_dir = labels_dir
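        # Minimal transform (assumed): convert the PIL image to a float tensor in [0, 1].
        # Add further transforms such as normalization here if needed.
        self.transform = T.Compose([T.ToTensor()])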

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):

        img_path = os.path.join(self.images_dir, self.images[i])
        msk_path = os.path.join(self.labels_dir, self.labels[i])

        img = Image.open(img_path).convert('RGB')
        lbl = Image.open(msk_path).convert('L')

        img = img.resize((512, 512))
        # Nearest-neighbour resampling keeps the mask values valid class indices
        lbl = lbl.resize((512, 512), Image.NEAREST)

        image = self.transform(img)
        label = torch.from_numpy(np.array(lbl)).long()

        return image, label

The Dataloader:

train_dataset = SegmentationDataset('/path/to/train/images/', '/path/to/train/labels/')
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4)

# You could write similar code for the validation set

Model training:

for epoch in range(50):
    for idx, data in enumerate(train_dataloader):
        images, labels = data[0], data[1]
        # Rest of the training code
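
To give an idea of what the rest of the training code might look like, here is a minimal sketch. The tiny Conv2d model, NUM_CLASSES, the random stand-in data and the train_one_epoch helper are all assumptions for illustration and are not from the original post; in practice you would plug in your own segmentation network and the train_dataloader from above.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

NUM_CLASSES = 3  # hypothetical number of segmentation classes

# Random stand-in data: images are (N, 3, H, W), labels are (N, H, W) class indices
images = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, NUM_CLASSES, (4, 32, 32))
loader = DataLoader(TensorDataset(images, labels), batch_size=2, shuffle=True)

# Placeholder model: a single conv layer that outputs one channel per class
model = nn.Conv2d(3, NUM_CLASSES, kernel_size=3, padding=1)
criterion = nn.CrossEntropyLoss()  # expects (N, C, H, W) logits and (N, H, W) long targets
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)


def train_one_epoch(model, loader, criterion, optimizer):
    """Run one pass over the loader and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    for batch_images, batch_labels in loader:
        optimizer.zero_grad()
        logits = model(batch_images)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


epoch_loss = train_one_epoch(model, loader, criterion, optimizer)
```

The labels stay as integer class indices without a channel dimension, which matches what the Dataset above returns via .long().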

Any suggestions for a 3D dataset? I guess PIL will not work.

What is the format of your dataset? I think you should be able to make the necessary changes in the Dataset class. Instead of loading images/labels with PIL, load each 3D volume (a NumPy array or a .nii file) in __getitem__ and then crop the number of slices to whatever fits in your memory.
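
A rough sketch of that idea, assuming each volume and its label are stored as .npy files with shape (depth, height, width). The class name VolumeDataset, the num_slices parameter, the centre crop and the demo files are assumptions made for illustration; for .nii files you would swap np.load for a NIfTI reader such as nibabel.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset


class VolumeDataset(Dataset):
    """Loads (volume, label) pairs saved as .npy arrays of shape (D, H, W)."""

    def __init__(self, volumes_dir, labels_dir, num_slices=16):
        self.volumes_dir = volumes_dir
        self.labels_dir = labels_dir
        self.volumes = sorted(os.listdir(volumes_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.num_slices = num_slices  # how many slices to keep per volume

    def __len__(self):
        return len(self.volumes)

    def __getitem__(self, i):
        vol = np.load(os.path.join(self.volumes_dir, self.volumes[i]))
        lbl = np.load(os.path.join(self.labels_dir, self.labels[i]))

        # Centre-crop along the depth axis to limit memory use
        start = max((vol.shape[0] - self.num_slices) // 2, 0)
        vol = vol[start:start + self.num_slices]
        lbl = lbl[start:start + self.num_slices]

        volume = torch.from_numpy(vol).float().unsqueeze(0)  # (1, D, H, W)
        label = torch.from_numpy(lbl).long()                 # (D, H, W)
        return volume, label


# Small demo with one random volume written to a temporary folder
root = tempfile.mkdtemp()
vol_dir = os.path.join(root, "volumes")
lbl_dir = os.path.join(root, "labels")
os.makedirs(vol_dir)
os.makedirs(lbl_dir)
np.save(os.path.join(vol_dir, "case0.npy"), np.random.rand(20, 8, 8).astype(np.float32))
np.save(os.path.join(lbl_dir, "case0.npy"), np.random.randint(0, 2, size=(20, 8, 8)))

dataset = VolumeDataset(vol_dir, lbl_dir, num_slices=16)
volume, label = dataset[0]
```

The same DataLoader call as before should work on this dataset, as long as every cropped volume ends up with the same shape.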