How to sample images belonging to particular classes

I have an ImageFolder dataset. I see that its __getitem__(index) method returns both the image and the class of the sample at that index. But what if I want to use ImageFolder to sample images from only one particular class of my choice?
How would this be done?

You could provide a WeightedRandomSampler with all weights set to zero for the samples of the unwanted classes.
Alternatively, you could write a custom Dataset that just reads all images from a particular folder, which will also result in only a single class being sampled.
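A minimal sketch of the WeightedRandomSampler approach. The TensorDataset below is a hypothetical toy stand-in for an ImageFolder; with a real ImageFolder you would build targets from torch.tensor(dataset.targets) instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy dataset standing in for an ImageFolder: 10 fake "images", 3 classes
targets = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
dataset = TensorDataset(torch.randn(10, 3, 8, 8), targets)

wanted_class = 0
# Weight 1 for samples of the wanted class, 0 for every other sample
weights = (targets == wanted_class).double()

# Zero-weight samples are never drawn; with replacement=True a wanted sample may repeat
sampler = WeightedRandomSampler(weights, num_samples=int(weights.sum()), replacement=True)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for data, target in loader:
    print(target)  # only the wanted class appears
```

Setting num_samples to the number of wanted samples keeps an "epoch" roughly the size of that class.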

So currently I am using a custom dataset, and the class looks like this:

import cv2
import numpy as np
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class CustomData(Dataset):
    """CustomData dataset"""

    def __init__(self, name, dirpath, transform=None, should_invert=False):
        super().__init__()
        self.name = name
        self.dirpath = dirpath
        self.imageFolderDataset = dset.ImageFolder(root=self.dirpath)
        self.transform = transform
        self.should_invert = should_invert
        self.to_tensor = transforms.ToTensor()

    def __getitem__(self, index):
        # Training images: list of (path, class_index) tuples
        images = self.imageFolderDataset.imgs
        img = cv2.imread(images[index][0])  # BGR uint8 array

        if self.should_invert:
            # PIL.ImageOps.invert only accepts PIL images, so invert the array directly
            img = 255 - img

        if self.transform is not None:
            img = self.transform(img)

        img = np.array(img, dtype='uint8')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_as_tensor = self.to_tensor(img)
        return img_as_tensor, images[index][1]

    def __len__(self):
        return len(self.imageFolderDataset.imgs)

How can I read from only one particular folder? ImageFolder raises an error if the root directory contains no class subdirectories.

An easy solution would be to copy the desired class folder into a separate root directory, but that would require moving data around, which you might want to avoid.
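If you go that route, one way to avoid copying any images, assuming your filesystem supports symbolic links (on Windows this may need extra privileges), is to build a new root that only contains a link to the desired class folder. single_class_root is a hypothetical helper, not part of torchvision:

```python
import os
import tempfile


def single_class_root(class_dir: str) -> str:
    """Hypothetical helper: create a temporary root holding only a symlink to class_dir.

    ImageFolder expects class subfolders inside root, so pointing it at this new
    root makes class_dir the one and only class without copying image files.
    """
    class_dir = os.path.abspath(class_dir)  # also strips a trailing slash
    root = tempfile.mkdtemp(prefix="single_class_root_")
    link_path = os.path.join(root, os.path.basename(class_dir))
    os.symlink(class_dir, link_path, target_is_directory=True)
    return root
```

You would then pass the returned path as root to ImageFolder; its directory scan should follow the symlinked folder. The temporary root is not removed automatically.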

I would suggest passing the image paths of the single class you want and just sampling from them:

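def __init__(self, image_paths):
    self.image_paths = image_paths  # Should contain a list of image paths of your desired class: e.g. ['./data/class0/img0.png', './data/class0/img1.png', ...]
    self.to_tensor = transforms.ToTensor()

def __getitem__(self, index):
    img = cv2.imread(self.image_paths[index])
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_as_tensor = self.to_tensor(img)
    return img_as_tensor, torch.tensor([0])  # single class, so the target is always 0

def __len__(self):
    return len(self.image_paths)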
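A more complete version of this idea might build the path list from one folder itself, so that only a directory has to be passed in. This is a sketch; SingleFolderDataset, pil_rgb_loader and the default extensions are hypothetical choices, not torchvision APIs.

```python
import os
from typing import Callable, Optional, Tuple

from torch.utils.data import Dataset


def pil_rgb_loader(path: str):
    # Imported lazily so that building the dataset does not require Pillow
    from PIL import Image
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")


class SingleFolderDataset(Dataset):
    """Hypothetical dataset: every image file directly inside `folder`, all labeled `label`."""

    def __init__(
        self,
        folder: str,
        label: int = 0,
        transform: Optional[Callable] = None,
        loader: Callable = pil_rgb_loader,
        extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png"),
    ):
        # Build the path list from the folder, so callers only pass a directory
        self.image_paths = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(extensions)
        )
        self.label = label
        self.transform = transform
        self.loader = loader

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, index: int):
        img = self.loader(self.image_paths[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, self.label
```

For example, SingleFolderDataset("./data/class0", transform=transforms.ToTensor()) would yield (tensor, 0) pairs for that folder only.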

But wouldn't it be possible without passing image_paths? I was kind of trying to avoid this approach.

In that case, I would just use a SubsetRandomSampler based on the class indices.
Here is a small example getting the class indices for class0 from an ImageFolder dataset and creating the SubsetRandomSampler:

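import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms

dataset = datasets.ImageFolder(root="root", transform=transforms.ToTensor())  # "root": path to your data

targets = torch.tensor(dataset.targets)
target_idx = (targets == 0).nonzero().flatten().tolist()  # indices of all class0 samples

sampler = SubsetRandomSampler(target_idx)

loader = DataLoader(dataset, sampler=sampler)

for data, target in loader:
    print(target)  # should only print zeros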

You could create a member method inside CustomData to return the class indices and pass them to the sampler.
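A minimal sketch of such a method, assuming CustomData keeps its ImageFolder in self.imageFolderDataset as in the class above. ClassIndicesMixin and class_indices are hypothetical names:

```python
from typing import List


class ClassIndicesMixin:
    """Hypothetical mixin: assumes self.imageFolderDataset exposes .targets (as ImageFolder does)."""

    def class_indices(self, class_idx: int) -> List[int]:
        # Positions of all samples whose target equals class_idx
        targets = self.imageFolderDataset.targets
        return [i for i, t in enumerate(targets) if t == class_idx]
```

Declaring the class as class CustomData(ClassIndicesMixin, Dataset) would then let you write sampler = SubsetRandomSampler(dataset.class_indices(0)).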


Again I have to specify dataset.targets somewhere. I just want to do something like passing the directory path and then loading images from there. Is it possible to do this without using a list of paths, the way ImageFolder does it?

Internally ImageFolder creates these paths, so one approach would be to have only one subfolder inside root, containing your desired class images.

Any other approach would make using ImageFolder rather pointless, as you would have to filter out the unwanted classes.
A kind of dirty hack would be to use ImageFolder and increase the index inside __getitem__ until you sample an image from the class you want.
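A sketch of that hack as a wrapper, assuming the wrapped dataset exposes .targets like ImageFolder does. SkipToClassDataset is a hypothetical name; a plain class with __getitem__ and __len__ can be passed to a DataLoader as a map-style dataset.

```python
class SkipToClassDataset:
    """Hypothetical wrapper implementing the "increase the index" hack.

    `base` is assumed to be an ImageFolder-like dataset exposing `.targets`.
    """

    def __init__(self, base, wanted_class: int):
        if wanted_class not in base.targets:
            raise ValueError(f"class {wanted_class} has no samples")
        self.base = base
        self.wanted_class = wanted_class

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, index: int):
        targets = self.base.targets
        # Walk forward, wrapping around, until we reach a sample of the wanted class
        while targets[index] != self.wanted_class:
            index = (index + 1) % len(targets)
        return self.base[index]
```

This also shows why it is a hack: ImageFolder orders samples by class, so every index from another class lands on the same few images of the wanted class, and those get heavily oversampled.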

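Given a list classes containing the names of the class folders to keep, I did the following:

import os
import torchvision

classes = ["class_a", "class_b"]  # placeholder: the class folder names to keep

def checkfun(path):
    # keep only .jpg files whose parent folder is one of the wanted classes
    return os.path.basename(os.path.dirname(path)) in classes and path.endswith(".jpg")

def ___find_classes(self, dir):
    # report only the wanted classes instead of every subdirectory of dir
    return classes, {c: i for i, c in enumerate(classes)}

torchvision.datasets.ImageFolder._find_classes = ___find_classes  # torchvision < 0.10
torchvision.datasets.ImageFolder.find_classes = ___find_classes   # torchvision >= 0.10
dataset = torchvision.datasets.ImageFolder(root="root_training_dir", is_valid_file=checkfun)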


  • I've created a filter function to filter out images of the wrong classes.
  • torchvision's _find_classes function just lists every subdirectory before the filtering is applied, so I've replaced it with a new function. The better thing to do here is probably to build a context manager that switches back to the original function once I'm done.
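The context manager mentioned in the last bullet could be sketched as follows. patched_method is a hypothetical helper, not a torchvision API; it swaps a method on a class and always restores the original, deleting the override if the method was only inherited from a parent class.

```python
import contextlib


@contextlib.contextmanager
def patched_method(cls, name, new_func):
    """Temporarily replace cls.<name> with new_func, restoring it on exit."""
    had_own = name in cls.__dict__        # defined on cls itself, or only inherited?
    original = cls.__dict__.get(name)     # raw attribute, keeps staticmethod wrappers intact
    setattr(cls, name, new_func)
    try:
        yield cls
    finally:
        if had_own:
            setattr(cls, name, original)
        else:
            delattr(cls, name)  # drop the override so the inherited method is visible again
```

With the functions from above, usage would look like: with patched_method(torchvision.datasets.ImageFolder, "find_classes", ___find_classes): dataset = torchvision.datasets.ImageFolder(root="root_training_dir", is_valid_file=checkfun). On torchvision versions older than 0.10 the attribute name would be "_find_classes".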

I just posted on another thread, but I think that answer applies to this question as well and may answer it best:

I would also consider going one level above ImageFolder, to the DatasetFolder class it inherits from.
DatasetFolder uses a method to index the subdirectories of the root folder, one per class:

import os
from typing import Dict, List, Tuple


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

See here: torchvision.datasets.folder — Torchvision 0.10.0 documentation

You could simply create your own DatasetFolder (it inherits from VisionDataset, so don't forget to inherit from that) and then let your own ImageFolder class inherit from it (if that is even needed).
In your own DatasetFolder, define a new find_classes method that only picks up the subdirectories of your dir whose names match your desired class names:

import os
from typing import Dict, List, Tuple


def find_classes(directory: str, desired_class_names: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset, keeping only the desired class names."""
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    classes = [c for c in classes if c in desired_class_names]  # keep only the desired class folders
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

Hope that makes sense and helps! I also just came across this question and am going to work on it tomorrow!

Maybe I should open a pull request or otherwise contribute to the official PyTorch repo by adding a parameter to ImageFolder that filters classes? Does this make any sense, @ptrblck?

I think your approach is great and the right way to implement this filtering.
You could surely create a feature request on GitHub and discuss the suggestion with the code owners.
I don’t know how often this feature of skipping specific folders is used, but maybe having the ability to pass a custom find_classes or make_dataset function might make sense (same as passing a custom loader).


@ptrblck Thanks for your evaluation, since I have little experience with submitting such suggestions!

Your suggestion is great and thanks for the feedback! Could you CC me on your feature request in GitHub, too?

@ptrblck Sure I will do that! I am going to create a request soon and tag you on GitHub! Edit: Done! Feature Request for torchvision ImageFolder using/inheriting DatasetFolder · Issue #4633 · pytorch/vision · GitHub
