Using ImageFolder, random_split with multiple transforms

Hi! I tried your approach, and it seems to work with num_workers=0, but I run into a broken pipe error if num_workers is >0. I assume it's something wrong with my dataset class, but I'm not sure what to do about it; any suggestions would be greatly appreciated!

import numpy as np
import torch
from torch.utils.data import Dataset, Subset
from torchvision import datasets, transforms

trans = transforms.Compose([transforms.RandomRotation(25),
                            transforms.RandomResizedCrop(224),
                            transforms.RandomHorizontalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

transNoAugment = transforms.Compose([transforms.Resize(255),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])


# https://discuss.pytorch.org/t/why-do-we-need-subsets-at-all/49391/7
# adapted from ptrblck post
class MyLazyDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        # apply this dataset's own transform to the shared, untransformed base dataset
        if self.transform:
            x = self.transform(self.dataset[index][0])
        else:
            x = self.dataset[index][0]
        y = self.dataset[index][1]
        return x, y

    def __len__(self):
        return len(self.dataset)
    
# Load entire dataset once
data_dir = 'C:/datasets/kaggleflower/flowers/'
dataset = datasets.ImageFolder(data_dir)

traindataset = MyLazyDataset(dataset, trans)
valdataset = MyLazyDataset(dataset, transNoAugment)
testdataset = MyLazyDataset(dataset, transNoAugment)

# Create the index splits for training, validation and test (80/10/10)
train_size = 0.8
num_train = len(dataset)
indices = list(range(num_train))
split = int(np.floor(train_size * num_train))
# validation and test each take half of the remaining (1 - train_size) fraction
split2 = int(np.floor((train_size + (1 - train_size) / 2) * num_train))
np.random.shuffle(indices)
train_idx, valid_idx, test_idx = indices[:split], indices[split:split2], indices[split2:]

traindata = Subset(traindataset, indices=train_idx)
valdata = Subset(valdataset, indices=valid_idx)
testdata = Subset(testdataset, indices=test_idx)

num_workers = 4
batch_size = 32

trainLoader = torch.utils.data.DataLoader(traindata, batch_size=batch_size, 
                                          num_workers=num_workers, drop_last=True)
valLoader = torch.utils.data.DataLoader(valdata, batch_size=batch_size, 
                                          num_workers=num_workers, drop_last=True)
testLoader = torch.utils.data.DataLoader(testdata, batch_size=batch_size,
                                          num_workers=num_workers, drop_last=True)

for x,y in testLoader:
    print (y)

---------------------------------------------------------------------------
BrokenPipeError                           Traceback (most recent call last)
<ipython-input-7-0c933bc9f14c> in <module>
----> 1 for x,y in testLoader:
      2     print (y)

~\.conda\envs\py37\lib\site-packages\torch\utils\data\dataloader.py in __iter__(self)
    276             return _SingleProcessDataLoaderIter(self)
    277         else:
--> 278             return _MultiProcessingDataLoaderIter(self)
    279 
    280     @property

~\.conda\envs\py37\lib\site-packages\torch\utils\data\dataloader.py in __init__(self, loader)
    680             #     before it starts, and __del__ tries to join but will get:
    681             #     AssertionError: can only join a started process.
--> 682             w.start()
    683             self.index_queues.append(index_queue)
    684             self.workers.append(w)

~\.conda\envs\py37\lib\multiprocessing\process.py in start(self)
    110                'daemonic processes are not allowed to have children'
    111         _cleanup()
--> 112         self._popen = self._Popen(self)
    113         self._sentinel = self._popen.sentinel
    114         # Avoid a refcycle if the target function holds an indirect

~\.conda\envs\py37\lib\multiprocessing\context.py in _Popen(process_obj)
    221     @staticmethod
    222     def _Popen(process_obj):
--> 223         return _default_context.get_context().Process._Popen(process_obj)
    224 
    225 class DefaultContext(BaseContext):

~\.conda\envs\py37\lib\multiprocessing\context.py in _Popen(process_obj)
    320         def _Popen(process_obj):
    321             from .popen_spawn_win32 import Popen
--> 322             return Popen(process_obj)
    323 
    324     class SpawnContext(BaseContext):

~\.conda\envs\py37\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
     87             try:
     88                 reduction.dump(prep_data, to_child)
---> 89                 reduction.dump(process_obj, to_child)
     90             finally:
     91                 set_spawning_popen(None)

~\.conda\envs\py37\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)
     61 
     62 #

BrokenPipeError: [Errno 32] Broken pipe
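
For what it's worth, the traceback goes through popen_spawn_win32 and the failing cell is an ipython-input, so this is Windows + Jupyter with spawn-based workers. My guess is the loader creation/iteration needs to sit behind the usual main-module guard when run as a script; a minimal, untested sketch of what I mean (using a stand-in TensorDataset instead of the ImageFolder pipeline above):

import torch
from torch.utils.data import DataLoader, TensorDataset

def main():
    # stand-in dataset just to illustrate the guard; the real code would build
    # the ImageFolder / MyLazyDataset / Subset pipeline from the post here
    data = TensorDataset(torch.randn(64, 3, 224, 224), torch.zeros(64, dtype=torch.long))
    loader = DataLoader(data, batch_size=32, num_workers=4, drop_last=True)
    for x, y in loader:
        print(y)

if __name__ == '__main__':
    # on Windows, DataLoader workers are started with spawn (popen_spawn_win32
    # in the traceback), which re-imports this module in every worker process;
    # keeping the loader work behind this guard is the usual workaround
    main()

Is that the right direction, or is something else in my dataset class breaking the worker pickling?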