Using ImageFolder, random_split with multiple transforms

Folks, I downloaded the flowers dataset (images of 5 classes), which I load with ImageFolder. I then split the entire dataset into training, validation and test sets using torch.utils.data.random_split.

The issue I am finding is that I have two different transforms I want to apply. One for training which has data augmentation, another for validation and testing which does not.

Question: what is the best way to apply the two different transforms to the three datasets? Unfortunately, passing the transform to ImageFolder won't work, as it applies the same transform to every image.

I found split-folders (https://pypi.org/project/split-folders), which splits the dataset into training, validation and test folders. I guess I could then use three separate calls to ImageFolder to build the datasets, each with its own transform. Is there a better way that uses just a single call to ImageFolder?

Thanks in advance for any help on this one!
Jacob

import torch
from torchvision import datasets, transforms

# data_dir, batch_size and num_workers are assumed to be defined earlier

# Training - random rotation, crop and flip for data augmentation
data_transform_train = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Validation and Testing - just resize and crop the images
data_transform = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# The problem: ImageFolder takes a single transform, so the augmenting
# data_transform_train ends up applied to all three splits
dataset = datasets.ImageFolder(data_dir, transform=data_transform_train)
train, val, test = torch.utils.data.random_split(dataset, [3459, 432, 432])
trainLoader = torch.utils.data.DataLoader(train, batch_size=batch_size,
                                          num_workers=num_workers, drop_last=True, shuffle=True)
valLoader = torch.utils.data.DataLoader(val, batch_size=batch_size,
                                        num_workers=num_workers, drop_last=True)
testLoader = torch.utils.data.DataLoader(test, batch_size=batch_size,
                                         num_workers=num_workers, drop_last=True)

Since ImageFolder will lazily load the data in its __getitem__ method, you could create three different dataset instances for training, validation, and test and could pass the appropriate transformation to them.

You could then create the sample indices via torch.arange(nb_samples) (or the numpy equivalent) and either split these indices manually or with e.g. a sklearn method, which also allows you to apply stratified splits etc.
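A minimal sketch of the index split described above, using scikit-learn's `train_test_split` with `stratify` so each split keeps the class proportions. The 80/10/10 ratios, the seed and the helper name `stratified_split_indices` are assumptions for illustration; for an ImageFolder, the labels are already available as `dataset.targets`, so no images need to be loaded to compute the split.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def stratified_split_indices(targets, val_frac=0.1, test_frac=0.1, seed=0):
    """Split sample indices into train/val/test index arrays, stratified by label."""
    targets = np.asarray(targets)
    indices = np.arange(len(targets))
    # First carve off the combined validation + test portion.
    train_idx, rest_idx = train_test_split(
        indices, test_size=val_frac + test_frac,
        stratify=targets, random_state=seed)
    # Then split that remainder into validation and test, again stratified.
    rel_test = test_frac / (val_frac + test_frac)
    val_idx, test_idx = train_test_split(
        rest_idx, test_size=rel_test,
        stratify=targets[rest_idx], random_state=seed)
    return train_idx, val_idx, test_idx
```

Usage would look like `train_idx, val_idx, test_idx = stratified_split_indices(dataset.targets)`, after which the index arrays can be passed to Subset or a sampler.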

These indices can then be passed together with a dataset to a Subset instance to create the final datasets.
The passed indices will be used to load the samples, and since you won’t load any data beforehand, you won’t waste memory.

Alternatively, you could pass the indices to a SubsetRandomSampler and pass that sampler, together with the corresponding dataset, to a DataLoader instance.
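A minimal sketch of the sampler-based variant, assuming `train_ds` and `eval_ds` hold the same samples in the same order (for example two ImageFolder instances on the same root with different transforms); `make_loaders` is a hypothetical helper name:

```python
from torch.utils.data import DataLoader, SubsetRandomSampler


def make_loaders(train_ds, eval_ds, train_idx, val_idx, test_idx,
                 batch_size=32, num_workers=0):
    """Build loaders that only draw the given indices from each dataset."""
    def loader(ds, idx):
        # The sampler yields the given indices in a new random order each
        # epoch; DataLoader rejects shuffle=True when a sampler is passed,
        # so shuffle stays at its default here.
        return DataLoader(ds, batch_size=batch_size,
                          sampler=SubsetRandomSampler(idx),
                          num_workers=num_workers)

    return (loader(train_ds, train_idx),
            loader(eval_ds, val_idx),
            loader(eval_ds, test_idx))
```

Compared with Subset, this keeps the full datasets intact and moves the split into the loader; a random order for validation and test batches does not change aggregate metrics.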

Let me know if that works.

Thank you so much! I will try these and get back!

Hi! I tried your approach. It seems to work with num_workers=0, but I am running into a broken pipe error when num_workers > 0. I assume it's something wrong with my dataset class, but I'm not sure what to do about it. Any suggestions would be greatly appreciated!

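import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms

# Training transform with augmentation
trans = transforms.Compose([transforms.RandomRotation(25),
                            transforms.RandomResizedCrop(224),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Validation/test transform without augmentation
transNoAugment = transforms.Compose([transforms.Resize(255),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])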


# https://discuss.pytorch.org/t/why-do-we-need-subsets-at-all/49391/7
# adapted from ptrblck post
class MyLazyDataset(Dataset):
    """Applies `transform` on top of a dataset that returns (PIL image, label) pairs."""
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        # Read from self.dataset (not the global `dataset`) and load the sample only once
        x, y = self.dataset[index]
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.dataset)
    
# Load entire dataset once
data_dir = 'C:/datasets/kaggleflower/flowers/'
dataset = datasets.ImageFolder(data_dir)

traindataset = MyLazyDataset(dataset,trans)
valdataset = MyLazyDataset(dataset,transNoAugment)
testdataset = MyLazyDataset(dataset,transNoAugment)

# Create the index splits for training, validation and test (80/10/10)
train_size = 0.8
num_train = len(dataset)
indices = list(range(num_train))
split = int(np.floor(train_size * num_train))                            # end of training block (80%)
split2 = int(np.floor((train_size + (1 - train_size) / 2) * num_train))  # end of validation block (90%)
np.random.shuffle(indices)
# e.g. with 4323 images: 3458 training, 432 validation and 433 test indices
train_idx, valid_idx, test_idx = indices[:split], indices[split:split2], indices[split2:]

traindata = Subset(traindataset, indices=train_idx)
valdata = Subset(valdataset, indices=valid_idx)
testdata = Subset(testdataset, indices=test_idx)

num_workers = 4
batch_size = 32

# shuffle=True reshuffles the training samples every epoch
trainLoader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True,
                                          num_workers=num_workers, drop_last=True)
valLoader = torch.utils.data.DataLoader(valdata, batch_size=batch_size,
                                        num_workers=num_workers, drop_last=True)
testLoader = torch.utils.data.DataLoader(testdata, batch_size=batch_size,
                                         num_workers=num_workers, drop_last=True)

for x,y in testLoader:
    print (y)

---------------------------------------------------------------------------
BrokenPipeError                           Traceback (most recent call last)
<ipython-input-7-0c933bc9f14c> in <module>
----> 1 for x,y in testLoader:
      2     print (y)

~\.conda\envs\py37\lib\site-packages\torch\utils\data\dataloader.py in __iter__(self)
    276             return _SingleProcessDataLoaderIter(self)
    277         else:
--> 278             return _MultiProcessingDataLoaderIter(self)
    279 
    280     @property

~\.conda\envs\py37\lib\site-packages\torch\utils\data\dataloader.py in __init__(self, loader)
    680             #     before it starts, and __del__ tries to join but will get:
    681             #     AssertionError: can only join a started process.
--> 682             w.start()
    683             self.index_queues.append(index_queue)
    684             self.workers.append(w)

~\.conda\envs\py37\lib\multiprocessing\process.py in start(self)
    110                'daemonic processes are not allowed to have children'
    111         _cleanup()
--> 112         self._popen = self._Popen(self)
    113         self._sentinel = self._popen.sentinel
    114         # Avoid a refcycle if the target function holds an indirect

~\.conda\envs\py37\lib\multiprocessing\context.py in _Popen(process_obj)
    221     @staticmethod
    222     def _Popen(process_obj):
--> 223         return _default_context.get_context().Process._Popen(process_obj)
    224 
    225 class DefaultContext(BaseContext):

~\.conda\envs\py37\lib\multiprocessing\context.py in _Popen(process_obj)
    320         def _Popen(process_obj):
    321             from .popen_spawn_win32 import Popen
--> 322             return Popen(process_obj)
    323 
    324     class SpawnContext(BaseContext):

~\.conda\envs\py37\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
     87             try:
     88                 reduction.dump(prep_data, to_child)
---> 89                 reduction.dump(process_obj, to_child)
     90             finally:
     91                 set_spawning_popen(None)

~\.conda\envs\py37\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)
     61 
     62 #

BrokenPipeError: [Errno 32] Broken pipe

OK, after more digging I found that putting the DataLoader code under an if __name__ == '__main__': guard resolved the broken pipe issue with num_workers > 0:


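num_workers = 4
batch_size = 32

# On Windows, DataLoader workers are started with "spawn", which re-imports this
# script in every worker, so anything that starts workers (iterating a loader
# with num_workers > 0) must run only under this guard
if __name__ == '__main__':
    trainLoader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, shuffle=True,
                                              num_workers=num_workers, drop_last=True)
    valLoader = torch.utils.data.DataLoader(valdata, batch_size=batch_size,
                                            num_workers=num_workers, drop_last=True)
    testLoader = torch.utils.data.DataLoader(testdata, batch_size=batch_size,
                                             num_workers=num_workers, drop_last=True)

    for x, y in testLoader:
        print(y)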

Thanks so much for the help with this! If you have any comments / suggestions on the class I modified from one of your posts please let me know!


The code looks good, and I'm glad you figured out the if-clause protection.

Great, thanks so much for the help! It’s working well!

I think we can split the dataset returned by ImageFolder directly with random_split, wrap each split with its own transform, and pass the results to DataLoaders like this:

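import torchvision
from torch.utils.data import DataLoader, random_split

# MyLazyDataset, trans and transNoAugment as defined earlier in this thread;
# MyLazyDataset must read from self.dataset, otherwise the split is ignored
dataset = torchvision.datasets.ImageFolder('path')  # no transform here
train, val, test = random_split(dataset, [1009, 250, 250])
traindataset = MyLazyDataset(train, trans)           # augmentation for training only
valdataset = MyLazyDataset(val, transNoAugment)
testdataset = MyLazyDataset(test, transNoAugment)
num_workers = 2
batch_size = 6
trainLoader = DataLoader(traindataset, batch_size=batch_size,
                         num_workers=num_workers, shuffle=True)
valLoader = DataLoader(valdataset, batch_size=batch_size,
                       num_workers=num_workers)
testLoader = DataLoader(testdataset, batch_size=batch_size,
                        num_workers=num_workers)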

PS: I am writing this because I found this question useful