Use DataLoader to load images from a single class

Hi,
I am looking to build a GAN model in PyTorch, which requires me to load images of different classes separately. I have used torchvision.datasets.ImageFolder to load my dataset, which contains images of horses and zebras. How can I load images from only one class? That is, I want to be able to tell my data loader whether I want images of horses or zebras.

You could extract all indices for the desired class (e.g. horses) and create a torch.utils.data.Subset using these indices and the dataset before passing it to the DataLoader.
Alternatively, you could pass these indices to a SubsetRandomSampler and hand it to the DataLoader through its sampler argument.
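
For illustration, here is a minimal sketch of both options. The helper names (indices_of_class, loader_via_subset, loader_via_sampler) are made up for this example and are not part of PyTorch. It assumes the dataset exposes class_to_idx and targets, as torchvision's ImageFolder does.

```python
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler


def indices_of_class(dataset, class_name):
    """Return the positions of all samples labelled `class_name` (hypothetical helper)."""
    class_idx = dataset.class_to_idx[class_name]
    return [i for i, label in enumerate(dataset.targets) if label == class_idx]


def loader_via_subset(dataset, class_name, batch_size=5):
    """Option 1: wrap the dataset in a Subset that only contains one class."""
    subset = Subset(dataset, indices_of_class(dataset, class_name))
    return DataLoader(subset, batch_size=batch_size, shuffle=True)


def loader_via_sampler(dataset, class_name, batch_size=5):
    """Option 2: keep the full dataset and restrict sampling to one class."""
    sampler = SubsetRandomSampler(indices_of_class(dataset, class_name))
    # `sampler` and `shuffle=True` are mutually exclusive; the sampler already
    # draws the given indices in random order.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
```

With an ImageFolder dataset this would be called as, for example, loader_via_subset(gan_dataset, "horses"), where "horses" must match the name of the class subfolder.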

Hi,
Thank you for the solution, @ptrblck. Using torch.utils.data.Subset did the trick for me.

Here is how I managed to do it.

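import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# placeholder: folder that contains one subfolder per class
dataroot = "path/to/dataset"

gan_dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([
    transforms.Resize(256),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]))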

# the ImageFolder class has the following variables which can be used to find 
# the range of indices for Subset:

print(gan_dataset.class_to_idx)
# gan_dataset.imgs is a list of tuples of (file_path, class_index) for all items in the dataset
print(gan_dataset.imgs)

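# these index ranges are specific to my dataset; I read them off the printed
# gan_dataset.imgs, where the items of each class sit next to each other
horse_dset = Subset(gan_dataset, range(260, 1327))
zebra_dset = Subset(gan_dataset, range(1327, 2670))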

horse_loader = DataLoader(horse_dset, batch_size=5, shuffle=True)
zebra_loader = DataLoader(zebra_dset, batch_size=5, shuffle=True)
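
If you would rather not hard-code the index ranges, something like the sketch below might work. It builds one Subset per class from the dataset's targets and class_to_idx attributes (both available on ImageFolder). The function name split_by_class is hypothetical, and I have only tried the idea on a small stand-in dataset.

```python
from collections import defaultdict

from torch.utils.data import Subset


def split_by_class(dataset):
    """Return {class_name: Subset} covering every class in the dataset (hypothetical helper)."""
    # invert the name -> index mapping so labels can be turned back into names
    idx_to_class = {idx: name for name, idx in dataset.class_to_idx.items()}

    # collect the sample positions belonging to each class
    per_class = defaultdict(list)
    for position, label in enumerate(dataset.targets):
        per_class[idx_to_class[label]].append(position)

    return {name: Subset(dataset, positions) for name, positions in per_class.items()}
```

Each entry of the returned dict can then be passed to a DataLoader in the same way as horse_dset and zebra_dset above. The keys are the class subfolder names as they appear in class_to_idx.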