How to remove a class from a dataset?

I have an ImageFolder dataset containing many classes. I want to modify the classes without modifying the underlying folder.
An example of such a modification is removing certain classes, e.g. training and testing only on classes 2 and 4 out of the existing 5.
I am aware of torch.utils.data.Subset, but when I select the relevant indices this way, many of the dataset's attributes become stale: for example, if I remove all the indices of a certain class, dataset.classes keeps its original length and still contains the dropped class.
I thought this might not matter, because I can just adjust my tensors' shapes to the new number of classes, but some problems remain, e.g. the targets of the classes I am using are not shifted accordingly (if I remove a class, the target indices of all subsequent classes should decrease by 1).
Is there any way to remove a class cleanly?
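A minimal sketch of the index-plus-relabel idea, assuming the base dataset exposes `classes` and `targets` and returns `(sample, label)` pairs the way ImageFolder does. `RemappedSubset` is a hypothetical helper, not a PyTorch class: it keeps only the samples of the requested classes (as Subset would) and also rebuilds `classes`, `class_to_idx` and the labels so they run from 0 to K-1 in the original class order.

```python
from typing import Any, Dict, List, Sequence, Tuple


class RemappedSubset:
    """Hypothetical wrapper: keeps only samples of `keep` and relabels them 0..K-1.

    Assumes `base` behaves like ImageFolder: it has `.classes` (list of class
    names), `.targets` (one int label per sample) and returns (sample, label)
    pairs from __getitem__.
    """

    def __init__(self, base: Any, keep: Sequence[str]) -> None:
        keep_set = set(keep)
        unknown = keep_set - set(base.classes)
        if unknown:
            raise ValueError(f"unknown classes: {sorted(unknown)}")
        self.base = base
        # New class list, in the same order as in the original dataset.
        self.classes: List[str] = [c for c in base.classes if c in keep_set]
        self.class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(self.classes)}
        # Old label index -> new, contiguous label index.
        self.old_to_new: Dict[int, int] = {
            base.classes.index(c): new for c, new in self.class_to_idx.items()
        }
        # Indices of the samples whose class survived (what Subset would get).
        self.indices: List[int] = [
            i for i, t in enumerate(base.targets) if t in self.old_to_new
        ]
        self.targets: List[int] = [self.old_to_new[base.targets[i]] for i in self.indices]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[Any, int]:
        sample, old_label = self.base[self.indices[idx]]
        return sample, self.old_to_new[old_label]
```

Because it implements `__len__` and `__getitem__`, it should work as a map-style dataset with `DataLoader`, e.g. `RemappedSubset(ImageFolder("path/to/train"), keep=["class_2", "class_4"])` (path and class names are hypothetical). Images of the dropped classes are never loaded, although ImageFolder still scans the whole folder once.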

I think writing a custom Dataset using some ImageFolder methods (and filtering out the unwanted folders) might be the cleanest approach.

I know it's more of a software engineering question, but just to make sure, this is the inheritance chain:
ImageFolder -> DatasetFolder -> VisionDataset -> Dataset

So should I find the level at which the folder is read, copy the classes from that level downward, and modify the relevant one?

I would probably try to derive from or re-implement DatasetFolder, as this line of code seems to be the interesting one.
