What is an appropriate usage of ImageFolder/DatasetFolder in this situation (images stored in a batch-ed manner)?

Dear folks,

Hello, I am quite a newbie in deep learning with PyTorch…

I am trying to build an image classification model using PyTorch,

but in the process of pre-processing the dataset, I've gotten stuck on a problem.

Problem)

  • I want to read a bunch of .npy files in a directory (e.g. ‘/training’)

  • Each .npy file contains a tensor with shape (N, C, H, W), i.e. a mini-batch of images

    (e.g. file1: (1592, 3, 224, 224), file2: (1100, 3, 224, 224), …)

  • The mini-batch size N differs from file to file (1592, 1100, 683, …); see the short snippet after this list
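
For reference, loading one of these files looks roughly like this (the path is only an example):

import numpy as np

batch = np.load('/training/file1.npy')    # example path, not my real file name
print(batch.shape)                        # e.g. (1592, 3, 224, 224) -> N samples of shape (C, H, W)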

Questions)

  • AFAIK, the DatasetFolder and ImageFolder modules of torchvision.datasets are meant for a directory containing individual image files (e.g. 1592 independent .npy files, each with shape (1, 3, 224, 224)).

  • In my case, is there any recommended way to deal with this situation?
    (In the end I want to use DataLoader after constructing a (custom) DatasetFolder/ImageFolder module.)

  • I want to avoid having to read each mini-batch .npy file, split it into individual samples (1, C, H, W), and save them again as separate files…

Thank you, in advance.

Have a good day!

There are different ways of handling this use case.

The simplest would be to load the complete data into memory and use some conditions on the index inside __getitem__ to select the numpy array that contains the current sample.
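
To make that more concrete, here is a rough sketch (the class name and file names are made up, and I'm assuming all arrays fit into RAM):

import numpy as np
import torch
from torch.utils.data import Dataset

class NpyChunksDataset(Dataset):
    # Rough sketch: load every chunk file into memory and map a global
    # sample index to (chunk, local index) via the cumulative lengths.
    def __init__(self, files):
        self.chunks = [np.load(f) for f in files]                   # each chunk: (N_i, C, H, W)
        self.cum_lengths = np.cumsum([len(c) for c in self.chunks])

    def __len__(self):
        return int(self.cum_lengths[-1])

    def __getitem__(self, index):
        chunk_idx = int(np.searchsorted(self.cum_lengths, index, side='right'))
        local_idx = index if chunk_idx == 0 else index - int(self.cum_lengths[chunk_idx - 1])
        return torch.from_numpy(self.chunks[chunk_idx][local_idx])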

The other approach would be to create a separate Dataset for each numpy array.
Based on the lengths of the underlying numpy arrays, you could then create the indices and shuffle them (if shuffling is needed).

Once this is done, you could use Subset to apply these indices and enforce the shuffling, and then pass all the datasets to ConcatDataset.

Finally, wrap the ConcatDataset in a DataLoader; it will then load the data sequentially across datasets (but shuffled inside each dataset).

Here is an example code snippet showing this approach, as the explanation might be a bit vague:

import torch
from torch.utils.data import TensorDataset, DataLoader, Subset, ConcatDataset

dataset1 = TensorDataset(torch.arange(10).view(10, 1))
dataset2 = TensorDataset(torch.arange(10, 20).view(10, 1))

idx1 = torch.randperm(len(dataset1))
idx2 = torch.randperm(len(dataset2))

dataset1 = Subset(dataset1, idx1)
dataset2 = Subset(dataset2, idx2)

datasets = ConcatDataset((dataset1, dataset2))

loader = DataLoader(datasets, batch_size=5)
for data in loader:
    print(data)

Note that this approach would also load the complete dataset into memory.
If you want to avoid that, you could create the “next” dataset lazily, once the previous one has been used completely.
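
A rough sketch of that lazy variant (the file list is a placeholder, and I'm simply wrapping each array in a TensorDataset):

import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

files = ['file1.npy', 'file2.npy']  # placeholder file names

for epoch in range(2):
    for f in files:
        # only one chunk is held in memory at a time
        chunk = torch.from_numpy(np.load(f))                                 # (N_i, C, H, W)
        loader = DataLoader(TensorDataset(chunk), batch_size=5, shuffle=True)
        for batch, in loader:
            pass  # training step would go here
        del chunk, loader                                                    # free memory before the next file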

Let me know if either approach works for you.

Sorry for the super late reply.

Thanks to your insightful answer, I could deal with all the problems I had faced!

To deal with the OOM problem, I've utilized np.memmap to load data directly from storage.

Here is the code snippet:

import os
import json

import numpy as np
import torch
from torch.utils.data import Dataset, Subset, ConcatDataset, DataLoader

# PATCH_HEIGHT and PATCH_WIDTH are global constants defined elsewhere in my script

class PartialDataset(Dataset):
    """
    * Description: custom `Dataset` module for processing `.npy` files (N, C, H, W) (N > 1) grouped by date
    - i.e. mini-batched .npy file stored by date
    - Therefore, the number of samples 'N' differs from file to file...
    """    
    def __init__(self, read_path, date, transform=None):
        """
        * Arguments:
        - read_path (string): path of `.npy` files
        - date (string): date (yyyymmdd) used as the file name
        - transform (callable, optional): optional transform to be applied on a sample
        """
        self.transform = transform
        self.path = read_path
        self.date = date

        self.data = self.read_memmap(f'{os.path.join(self.path, self.date)}.npy')
        
    def read_memmap(self, file_name):
        """
        * Description: read a `.npy` file from the directory as an `np.memmap`
        
        * Argument:
        - file_name (string): path of '.npy' and '.npy.conf' files
        
        * Output:
        - whole data loaded in a memory-efficient manner (np.memmap) 
        """
        with open(file_name + '.conf', 'r') as file:
            memmap_configs = json.load(file) 
            return np.memmap(file_name, mode='r+', shape=tuple(memmap_configs['shape']), dtype=memmap_configs['dtype'])

    def __getitem__(self, index):
        """
        * Description: function for indexing samples
        
        * Argument:
        - index (int): index of the sample
        
        * Output:
        - input data, output data (torch.Tensor, torch.Tensor)
        - (batch_size, 4 (Mask(0 - background, 1 - foreground) / input1 / input2 / input3), height, width), (batch_size, output, height, width)
        """
        
        mask = torch.Tensor(self.data[index, 0, :, :]).reshape(1, PATCH_HEIGHT, PATCH_WIDTH)
        inputs = torch.Tensor(self.data[index, 2:4, :, :])
        output = torch.Tensor(self.data[index, 1, :, :]).reshape(1, PATCH_HEIGHT, PATCH_WIDTH)
            
        if self.transform is not None:
            inputs = self.transform(inputs)
            
        inputs = torch.cat([mask, inputs], dim=0)  # keep the result as a torch.Tensor
        return (inputs, output)

    def __len__(self):
        """
        * Description: function for returning the length of the dataset
        
        * Output:
        - length (int)
        """
        return self.data.shape[0]
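
(For completeness: read_memmap assumes each .npy file has a matching .npy.conf JSON sidecar storing its shape and dtype. This is roughly how I write such a pair; the helper name is just for illustration.)

def write_memmap(file_name, array):
    # Illustrative helper: dump the raw array to disk and store its
    # shape/dtype in a JSON sidecar so read_memmap can reconstruct it.
    mm = np.memmap(file_name, mode='w+', shape=array.shape, dtype=array.dtype)
    mm[:] = array[:]
    mm.flush()
    with open(file_name + '.conf', 'w') as file:
        json.dump({'shape': array.shape, 'dtype': str(array.dtype)}, file)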

After reading the separate *.npy files as PartialDataset instances, I stored them in lists like below:

def construct_partial_dataset(read_path, test_date=['20190212', '20190612', '20190912', '20191216'], transform=transform):
    training_list, test_list = [], []
    
    for path, dirs, files in os.walk(read_path):
        if dirs != []: continue 
        file_list = sorted(files)
        for file in file_list:
            if '.conf' in file: continue
                
            is_train = file[:8] not in test_date
            if is_train:
                training_list.append(PartialDataset(read_path=os.path.join(read_path, 'training'), date=file[:8], transform=transform))
            else:
                test_list.append(PartialDataset(read_path=os.path.join(read_path, 'validation'), date=file[:8], transform=transform))
    
    return training_list, test_list
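
I call it roughly like this (the root path is just a placeholder; it is expected to contain the 'training' and 'validation' sub-directories):

training_list, test_list = construct_partial_dataset(read_path='/data/patches', transform=transform)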

After constructing the two lists (training_list and test_list), I followed your advice and used Subset and ConcatDataset to construct the final training/test datasets containing all the corresponding arrays.

def concat_list_of_partial_datasets(dataset_list, num_samples=-1):
    """
    * Description: function for concatenating `PartialDataset` objects into a single dataset using `torch.utils.data.ConcatDataset`
    
    * Arguments:
    - dataset_list (list): list of `PartialDataset` instances
    - num_samples (int, optional): number of shuffled samples to keep per `PartialDataset` (-1 keeps all samples)
    
    * Output:
    - a dataset (torch.utils.data.ConcatDataset)
    """
    subsets = []
    for dataset in dataset_list:
        # shuffle each partial dataset and optionally keep only `num_samples` samples
        idx = torch.randperm(len(dataset))
        if num_samples > 0:
            idx = idx[:num_samples]
        subsets.append(Subset(dataset, idx))
    
    return ConcatDataset(subsets)

Finally, I could utilize DataLoader for loading data!

training_set = concat_list_of_partial_datasets(dataset_list=training_list, num_samples=8192)
test_set = concat_list_of_partial_datasets(dataset_list=test_list, num_samples=1024)

training_loader = DataLoader(training_set, batch_size=HYPERPARAMS['batch_size'], shuffle=True, num_workers=16)
test_loader = DataLoader(test_set, batch_size=HYPERPARAMS['batch_size'], shuffle=False, num_workers=16)

Since I am not that proficient in Python and PyTorch, I still believe there is a more memory-efficient and better way to handle a situation like mine…!
(Please let me know if my code has problems…)

Thank you again for your kind and gentle answer.
Have a good day!