Multi-stream CNN model: how to load multiple images simultaneously


I want to train a multi-stream CNN model with PyTorch. My data looks like this:

There are two folders, train1 and train2. Each of them contains two class folders, 1 and 2, and each class folder holds a number of images. The image names under train1 and train2 are the same:

----- 1:
a1.jpg b1.jpg c1.jpg d1.jpg
a2.jpg b2.jpg c2.jpg d2.jpg

----- 1:
a1.jpg b1.jpg c1.jpg d1.jpg
a2.jpg b2.jpg c2.jpg d2.jpg

In PyTorch, how can I load two images, one from train1 and one from train2, simultaneously? For example, with a batch size of 1,
train1/1/a1.jpg and train2/1/a1.jpg should be returned by the data loader at the same time.

How can I implement this? Any other ideas would be appreciated as well.

One possible approach would be to use ImageFolder internally in a custom Dataset and, for each sample, replace the root folder in its path to load the matching image from train2:

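from torch.utils.data import Dataset
from torchvision import datasets, transforms


class MyDataset(Dataset):
    def __init__(self, train1, transform=None):
        # ImageFolder indexes train1/<class>/<image> and assigns the labels
        self.dataset1 = datasets.ImageFolder(root=train1)
        self.transform = transform

    def __getitem__(self, index):
        # Get the image and label from train1
        x1, y = self.dataset1[index]
        # Build the path of the corresponding image in train2.
        # Note: str.replace swaps every 'train1' substring in the path.
        path = self.dataset1.samples[index][0]
        path = path.replace('train1', 'train2')
        # print('Loading ', path)  # uncomment to check the paired paths
        x2 = self.dataset1.loader(path)
        # Apply the transform to both images (random transforms are
        # sampled independently for x1 and x2)
        if self.transform:
            x1 = self.transform(x1)
            x2 = self.transform(x2)
        return x1, x2, y

    def __len__(self):
        return len(self.dataset1)


# ToTensor() turns the PIL images into tensors so a DataLoader can batch them
dataset = MyDataset(train1='./data/train1', transform=transforms.ToTensor())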
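The `path.replace('train1', 'train2')` step swaps every occurrence of `train1` in the path, so a root such as `./data/train1_v2/train1` would be mangled. Below is a minimal sketch of a stricter mapping that only swaps the root directory, assuming train2 mirrors train1 exactly. `counterpart_path` and `missing_counterparts` are hypothetical helper names, not part of torchvision:

```python
import os
from pathlib import Path


def counterpart_path(path, root1, root2):
    """Map a file under root1 to the file with the same relative path under root2.

    Hypothetical helper: only the root directory is swapped, so a 'train1'
    substring elsewhere in the path is left untouched.
    """
    # abspath normalizes './data/...' vs absolute roots without following symlinks;
    # relative_to raises ValueError if path is not inside root1
    relative = Path(os.path.abspath(path)).relative_to(os.path.abspath(root1))
    return Path(os.path.abspath(root2)) / relative


def missing_counterparts(samples, root1, root2):
    """Return the paths in `samples` whose counterpart under root2 is not a file.

    `samples` is a list of (path, class_index) tuples, e.g. ImageFolder.samples.
    """
    return [p for p, _ in samples if not counterpart_path(p, root1, root2).is_file()]
```

Inside `__getitem__`, the replace line could then become `path = str(counterpart_path(path, self.root1, self.root2))`, with `self.root1` and `self.root2` stored in `__init__` (hypothetical attributes). Calling `missing_counterparts(self.dataset1.samples, train1, train2)` once in `__init__` surfaces unmatched files early instead of failing in the middle of an epoch.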

Thanks very much, I will try this solution

Hello @ptrblck

Your solution works well. How can I set the batch size with this custom Dataset class?
For example, I want to set the batch size to 5 so that I get 5 image pairs per iteration during training.


Wrap the Dataset in a DataLoader, where you can specify the batch_size as well as other useful options such as shuffling and loading with multiple workers. Have a look at the DataLoader docs and the Data loading tutorial for more information on the data loading pipeline.
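A minimal sketch of the batching behaviour, using a hypothetical `FakePairDataset` that returns random tensors in place of `MyDataset` so it runs without image files. With `MyDataset`, each returned image must already be a tensor (e.g. via `transforms.ToTensor()`), because the default collate function cannot stack PIL images:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakePairDataset(Dataset):
    """Hypothetical stand-in for MyDataset: returns (x1, x2, y) with tensor images."""

    def __init__(self, n=12, size=(3, 32, 32)):
        self.n = n
        self.size = size

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        x1 = torch.randn(self.size)  # image from the first stream
        x2 = torch.randn(self.size)  # matching image from the second stream
        y = index % 2  # fake class label
        return x1, x2, y


loader = DataLoader(FakePairDataset(), batch_size=5, shuffle=True, num_workers=0)

for x1, x2, y in loader:
    # Each position of the returned tuple is stacked into its own batch tensor
    print(x1.shape, x2.shape, y.shape)
```

The default collate function stacks each position of the returned tuple separately, so the loop receives `x1` and `x2` batches of shape `[5, 3, 32, 32]` and a label tensor `y` of shape `[5]`. With 12 samples, the last batch holds the remaining 2 samples unless `drop_last=True` is set.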