Using only some classes of a data set

I am following this tutorial:

I'm using it to play around with some classifiers in PyTorch. Suppose I have a training data folder called train, and inside it 4 folders for the 4 classes A, B, C and D. PyTorch seems to have some very convenient functions for loading data and training on it, as in these lines of code from the linked tutorial:

import os
import torch
from torchvision import datasets

# data_dir, data_transforms and batch_size are defined earlier in the tutorial
# Create training and validation datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# Create training and validation dataloaders
dataloaders_dict = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4) for x in ['train', 'val']}

How can I modify this data loader so that it only loads a subset of the classes in the data directory? For example, what if I only wanted to train on classes A, B and C rather than all 4?

You could either write a custom Dataset (using ImageFolder as the base class and adding your class-filtering code to it), or create a new root folder containing symbolic links to the wanted class folders and pass this new folder as the root argument to ImageFolder.


Ahh, the symbolic links idea is good! Thanks!
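The symbolic-link route can be scripted with the standard library alone. This is a minimal sketch: the function name link_class_subset and the destination folder are hypothetical. It creates a new root containing one directory symlink per wanted class, which can then be passed as the root argument to ImageFolder (ImageFolder should follow the links when scanning for class folders). On Windows, creating symlinks may require administrator rights or Developer Mode.

```python
import os
from pathlib import Path


def link_class_subset(src_root, dst_root, keep_classes):
    """Hypothetical helper: build dst_root with symlinks to the wanted class folders."""
    src_root = Path(src_root).resolve()  # absolute targets keep the links valid anywhere
    dst_root = Path(dst_root)
    dst_root.mkdir(parents=True, exist_ok=True)
    for cls in keep_classes:
        src = src_root / cls
        if not src.is_dir():
            raise FileNotFoundError(f"Class folder not found: {src}")
        dst = dst_root / cls
        # Leave existing entries in place so the script can be re-run safely.
        if not dst.exists():
            os.symlink(src, dst, target_is_directory=True)
    return dst_root
```

Do the same for the val folder so that training and validation use the same classes, for example by calling link_class_subset(os.path.join(data_dir, x), os.path.join(subset_dir, x), ["A", "B", "C"]) for x in ['train', 'val'], where subset_dir is a hypothetical new directory, and then pointing ImageFolder at the returned paths.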