Using only some classes of a data set

I am following this tutorial:

I'm playing around with some classifiers in PyTorch. Suppose I have a training data folder called train, and within this train folder I have 4 folders for 4 classes: A, B, C and D. PyTorch seems to have a few very convenient functions for loading data and training on it, using these lines of code from the linked tutorial:

```python
import os
import torch
from torchvision import datasets

# Create training and validation datasets (data_dir, data_transforms and batch_size are defined in the tutorial)
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}
```

How can I modify this data loader so that it only loads a subset of the classes in the data directory? For example, what if I wanted to train only on classes A, B and C rather than all 4?

You could either write a custom Dataset (using ImageFolder as the base class and adding your class-filtering code to it), or create a new root folder containing symbolic links to the wanted class folders and pass this new folder as the root argument to ImageFolder.
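A minimal sketch of the symbolic-link idea, using only the standard library. The helper name `make_subset_root` and its parameters are hypothetical; it assumes a POSIX system (on Windows, creating symlinks may need extra privileges) and that ImageFolder follows symlinked class folders when scanning the root.

```python
import os
from pathlib import Path
from typing import Iterable


def make_subset_root(src_root: str, dst_root: str, wanted_classes: Iterable[str]) -> str:
    """Hypothetical helper: create dst_root with one symlink per wanted class folder.

    The returned path can be used as the root argument of ImageFolder, which
    then only sees the linked class folders.
    """
    src, dst = Path(src_root).resolve(), Path(dst_root)
    dst.mkdir(parents=True, exist_ok=True)
    for cls_name in wanted_classes:
        target = src / cls_name
        if not target.is_dir():
            raise FileNotFoundError(f"Class folder {target} does not exist.")
        link = dst / cls_name
        # Skip links that already exist so the helper can be re-run safely.
        if not (link.is_symlink() or link.exists()):
            # target_is_directory only matters on Windows.
            os.symlink(target, link, target_is_directory=True)
    return str(dst)
```

For the layout in the question you would call it once per split, e.g. `make_subset_root(os.path.join(data_dir, 'train'), os.path.join(data_dir, 'train_ABC'), ['A', 'B', 'C'])` (the `train_ABC` name is just an example), and pass the returned path to `datasets.ImageFolder` instead of the original folder. The original data is untouched, so switching class subsets only means creating another link folder.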


Ahh, the symbolic links idea is good! Thanks!

I would also consider going one level above ImageFolder, to the DatasetFolder class it inherits from.
DatasetFolder uses a find_classes method to index the class subdirectories of the root folder:

```python
import os
from typing import Dict, List, Tuple

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
```

See here: torchvision.datasets.folder — Torchvision 0.10.0 documentation

You could simply write your own version of DatasetFolder (which inherits from VisionDataset, so don't forget to inherit from that as well) and then, if it is needed at all, let your own ImageFolder class inherit from it.
In your own DatasetFolder, write a new find_classes method that only picks up the subdirectories of your directory whose names match your desired class names:

```python
def find_classes(directory: str, desired_class_names: List) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset."""
.......
```
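One way the filtered find_classes could look, as a sketch: it mirrors the torchvision function quoted above and only adds a name filter. The `desired_class_names` parameter comes from the post; everything else is an assumption, and it is written without torchvision so it can be tried on any folder.

```python
import os
from typing import Dict, Iterable, List, Tuple


def find_classes(directory: str, desired_class_names: Iterable[str]) -> Tuple[List[str], Dict[str, int]]:
    """Finds only the desired class folders in a dataset.

    Same as torchvision's find_classes, except that subdirectories whose
    names are not in desired_class_names are skipped.
    """
    wanted = set(desired_class_names)
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir() and entry.name in wanted)
    if not classes:
        raise FileNotFoundError(f"Couldn't find any of the class folders {sorted(wanted)} in {directory}.")
    # Indices are numbered 0..len(classes)-1 over the kept classes only,
    # so a 3-class model gets labels 0, 1, 2.
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
```

Note that wanted names with no matching folder are silently ignored unless none match at all. In recent torchvision versions (0.10 and later, as far as I know), DatasetFolder calls `self.find_classes(self.root)` from its `__init__`, so a subclass of ImageFolder could override `find_classes(self, directory)` to call a function like this; the wanted class list then has to be stored on the instance before calling `super().__init__`.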

Hope that makes sense and helps! I also just came across this problem myself and am going to tackle it tomorrow!
