ImageFolder DataLoader for ImageNet with selected classes and a pretrained PyTorch model

You can create a custom Dataset that returns the expected value corresponding to the original 1000 classes, for example:

import torchvision

class MyImageFolder(torchvision.datasets.ImageFolder):
    """ImageFolder that relabels every sample by its class folder *name*.

    Images whose folder name is in ``selected_classes`` keep that name as
    their target; images from any other folder get ``other_label``.
    After construction:

    - ``classes`` and ``class_to_idx`` describe only the selected classes;
    - ``samples``, ``imgs`` and ``targets`` hold the relabelled targets.

    Args:
        img_path: Root directory passed to ``ImageFolder``.
        transform: Optional image transform passed to ``ImageFolder``.
        selected_classes: Folder names that keep their own label.
        other_label: Target used for images from any other folder.

    NOTE: targets are strings, not integer indices. Map them (e.g. through
    ``class_to_idx`` or to a pretrained model's class indices) before using
    them with a loss function that expects integer class labels.
    """

    def __init__(self, img_path, transform=None, *,
                 selected_classes=('duck', 'wolf'),
                 other_label='not_my_favorite_class'):
        super().__init__(img_path, transform)
        # ImageFolder labelled each sample by its position in its own folder
        # list. Remember that list so targets are remapped by folder *name*:
        # positions shift whenever other folders exist under img_path.
        self._folder_classes = list(self.classes)
        self._selected_classes = list(selected_classes)
        self._other_label = other_label
        self.classes, self.class_to_idx = self._my_classes()
        self.samples = self._make_dataset(self.samples)
        self.imgs = self.samples
        self.targets = [target for _, target in self.samples]

    def _my_classes(self):
        """Return ``(classes, class_to_idx)`` for the selected class names."""
        classes = list(self._selected_classes)
        class_to_idx = {name: idx for idx, name in enumerate(classes)}
        return classes, class_to_idx

    def _make_dataset(self, samples):
        """Return ``samples`` with each folder index replaced by its new label."""
        return [(img, self._custom_class(cls)) for img, cls in samples]

    def _custom_class(self, cls):
        """Map an ImageFolder class index to its relabelled target.

        The index is translated back to the folder name it came from.
        Selected names are returned as-is; anything else becomes
        ``other_label``.
        """
        name = self._folder_classes[cls]
        return name if name in self.class_to_idx else self._other_label

This is a slight variation on the answer given here (the link from the original post was lost).

Hope this helps 🙂

1 Like