Multiple image inputs dataloader

Hi All,

I’m trying to create a dataloader with multiple image inputs (different resolutions related to each other). I’m not sure if I am going in the right direction. Essentially I have three related images, where they are stored in a data structure like this:

dataset — train
|------ 40x-----*png
|-------20x-----*png
|-------10x-----*png

However, I think perhaps I have an issue with indexing?

My code looks like this:

import os
import glob
import numpy as np
import random
from PIL import Image
from torch.utils import data
from torch.utils.data.dataset import Dataset
from torchvision import transforms as T


class Patches(data.Dataset):
    def __init__(self, root, phase):
        self.phase = phase

        imgs_40x = []
        imgs_20x = []
        imgs_5x = []

        if phase == 'train':
            for path, subdirs, files in os.walk(root):
                for sd in subdirs:

                    if sd == "20x":
                        twenty = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(twenty):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_20x.append(path + "/" + fn)

                    elif sd == "5x":
                        five = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(five):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_5x.append(path + "/" + fn)

                    elif sd == "40x":
                        forty = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(forty):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_40x.append(path + "/" + fn)

            self.imgs_40x = imgs_40x
            self.imgs_20x = imgs_20x
            self.imgs_5x = imgs_5x


        elif phase == 'val' or 'test':
            for path, subdirs, files in os.walk(root):
                for sd in subdirs:

                    if sd == "20x":
                        twenty = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(twenty):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_20x.append(path + "/" + fn)

                    elif sd == "5x":
                        five = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(five):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_5x.append(path + "/" + fn)

                    elif sd == "40x":
                        forty = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(forty):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_40x.append(path + "/" + fn)

            self.imgs_40x = imgs_40x
            self.imgs_20x = imgs_20x
            self.imgs_5x = imgs_5x

        if self.phase == 'train':
            self.transforms = T.Compose([T.RandomResizedCrop(224),T.ToTensor(),
                                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        else:
            self.transforms = T.Compose([T.RandomResizedCrop(224),T.ToTensor(),
                                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


    def __getitem__(self, index):
        if self.phase == 'train':
            path_40x = self.imgs_40x[index]
            path_20x = self.imgs_20x[index]
            path_5x = self.imgs_5x[index]
            data_40 = Image.open(path_40x).convert('RGB')
            data_40 = self.transforms(data_40)
            data_20 = Image.open(path_20x).convert('RGB')
            data_20 = self.transforms(data_20)
            data_5 = Image.open(path_5x).convert('RGB')
            data_5 = self.transforms(data_5)

            label = int(path_40x.split('/')[2])

        elif self.phase == 'val' or 'test':
            path_40x = self.imgs_40x[index]
            path_20x = self.imgs_20x[index]
            path_5x = self.imgs_5x[index]
            data_40 = Image.open(path_40x).convert('RGB')
            data_40 = self.transforms(data_40)
            data_20 = Image.open(path_20x).convert('RGB')
            data_20 = self.transforms(data_20)
            data_5 = Image.open(path_5x).convert('RGB')
            data_5 = self.transforms(data_5)

            label = int(path_40x.split('/')[2])

        return data_40, data_20, data_5, path_40x, path_20x, path_5x, label


    def __len__(self):
        return len(self.imgs_40x)


if __name__ == '__main__':
    root = 'mults/'
    train_dataset = Patches(root=root+ "train/", phase= 'train')
    trainloader = data.DataLoader(train_dataset, batch_size=5)
    print(len(trainloader))
    for i, (data_40, data_20, data_5, path_40x, path_20x, path_5x, label) in enumerate(trainloader):
        print(data_40)

Not really sure if I need to rewrite this entirely because it returns 0 for the length of the trainloader…

Hi,

What is the problem? Your dataset looks mostly good.
The reason why you might not get matching images is because you walk each subdirectory one by one and the order the images are returned is very OS/directory dependent. You might have to be more careful. Maybe iterating one directory (the 20x) and reconstructing the path of the images in the others directories that match the one you are currently reading in this one.

1 Like

Hi albanD,

Thanks for the speedy reply.

Yes, so I’ve amended the dataloader by reconstructing the paths (not sure if there is a more efficient/better way of doing it) and it appears to work so thanks!

Here is my code:

import os
import glob
import numpy as np
import random
from PIL import Image
from torch.utils import data
from torch.utils.data.dataset import Dataset
from torchvision import transforms as T


class Patches(data.Dataset):
    def __init__(self, root, phase):
        self.phase = phase

        imgs_40x = []
        imgs_20x = []
        imgs_5x = []

        if phase == 'train':
            for path, subdirs, files in os.walk(root):
                for sd in subdirs:

                    if sd == "20x":
                        twenty = os.path.join(path,sd)
                        for path, subdirs, files in os.walk(twenty):
                            for fn in files:
                                if fn.endswith(".png") == True:
                                    imgs_20x.append(path + "/" + fn)
                                    imgs_40x.append(path.strip("20x") + "40x/" + "40x" + fn.strip("20x"))
                                    imgs_5x.append(path.strip("20x") + "5x/" + "5x" + fn.strip("20x"))

            self.imgs_40x = imgs_40x
            self.imgs_20x = imgs_20x
            self.imgs_5x = imgs_5x


        elif phase == 'val' or 'test':
            for path, subdirs, files in os.walk(root):
                for sd in subdirs:

                        if sd == "20x":
                            twenty = os.path.join(path,sd)
                            for path, subdirs, files in os.walk(twenty):
                                for fn in files:
                                    if fn.endswith(".png") == True:
                                        imgs_20x.append(path + "/" + fn)
                                        imgs_40x.append(path.strip("20x") + "40x/" +"40x" + fn.strip("20x")
                                        imgs_5x.append(path.strip("20x") + "5x/" + "5x" + fn.strip("20x")



            self.imgs_40x = imgs_40x
            self.imgs_20x = imgs_20x
            self.imgs_5x = imgs_5x

        if self.phase == 'train':
            self.transforms = T.Compose([T.RandomResizedCrop(224),T.ToTensor(),
                                    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

        else:
            self.transforms = T.Compose([T.RandomResizedCrop(224),T.ToTensor(),
                                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


    def __getitem__(self, index):
        if self.phase == 'train':
            path_40x = self.imgs_40x[index]
            path_20x = self.imgs_20x[index]
            path_5x = self.imgs_5x[index]
            data_40 = Image.open(path_40x).convert('RGB')
            data_40 = self.transforms(data_40)
            data_20 = Image.open(path_20x).convert('RGB')
            data_20 = self.transforms(data_20)
            data_5 = Image.open(path_5x).convert('RGB')
            data_5 = self.transforms(data_5)

            label = int(path_40x.split('/')[2])

        elif self.phase == 'val' or 'test':
            path_40x = self.imgs_40x[index]
            path_20x = self.imgs_20x[index]
            path_5x = self.imgs_5x[index]
            data_40 = Image.open(path_40x).convert('RGB')
            data_40 = self.transforms(data_40)
            data_20 = Image.open(path_20x).convert('RGB')
            data_20 = self.transforms(data_20)
            data_5 = Image.open(path_5x).convert('RGB')
            data_5 = self.transforms(data_5)

            label = int(path_40x.split('/')[2])

        return data_40, data_20, data_5, path_40x, path_20x, path_5x, label


    def __len__(self):
        return len(self.imgs_40x)


if __name__ == '__main__':
    root = 'mults/'
    train_dataset = Patches(root=root+ "train/", phase= 'train')
    trainloader = data.DataLoader(train_dataset, batch_size=5)
    print(len(trainloader))
    for i, (data_40, data_20, data_5, path_40x, path_20x, path_5x, label) in enumerate(trainloader):
        print(path_40x[0], path_20x[0], path_5x[0])```

And here is the output from the first batch:

7
mults/train/0/5678.ndpi/40x/40x-236247-16635-80640-8704.png mults/train/0/5678.ndpi/20x/20x-236247-16635-80640-8704.png mults/train/0/5678.ndpi/5x/5x-236247-16635-80640-8704.png
mults/train/0/10001.ndpi/40x/40x-236247-10142-15872-5376.png mults/train/0/10001.ndpi/20x/20x-236247-10142-15872-5376.png mults/train/0/10001.ndpi/5x/5x-236247-10142-15872-5376.png
mults/train/0/10001.ndpi/40x/40x-236247-10151-18176-5376.png mults/train/0/10001.ndpi/20x/20x-236247-10151-18176-5376.png mults/train/0/10001.ndpi/5x/5x-236247-10151-18176-5376.png
mults/train/0/1234.ndpi/40x/40x-236247-16657-86272-8704.png mults/train/0/1234.ndpi/20x/20x-236247-16657-86272-8704.png mults/train/0/1234.ndpi/5x/5x-236247-16657-86272-8704.png
mults/train/1/5678.ndpi/40x/40x-236247-16634-80384-8704.png mults/train/1/5678.ndpi/20x/20x-236247-16634-80384-8704.png mults/train/1/5678.ndpi/5x/5x-236247-16634-80384-8704.png
mults/train/1/10001.ndpi/40x/40x-236247-10152-18432-5376.png mults/train/1/10001.ndpi/20x/20x-236247-10152-18432-5376.png mults/train/1/10001.ndpi/5x/5x-236247-10152-18432-5376.png
mults/train/1/1234.ndpi/40x/40x-236247-16656-86016-8704.png mults/train/1/1234.ndpi/20x/20x-236247-16656-86016-8704.png mults/train/1/1234.ndpi/5x/5x-236247-1665```

Yes that looks good :slight_smile:
If creating your Dataset is too slow, you could consider buffering the work done in the __init__ and dump self.imgs_40x, self.imgs_20x and self.imgs_5x into a text file so that you can simply reload them if they exist in the __init__.

1 Like

Brilliant, I’ll have a look into that. Thanks a lot. :grinning: