I am looking to build a GAN model using PyTorch, which requires me to load images of different classes separately. I have used torchvision.datasets.ImageFolder to load my dataset, which contains images of horses and zebras. How can I load the images from only one class? That is, I want to be able to tell my data loader whether I want images of horses or zebras.
You could extract all indices for the desired class (e.g. horses) and create a torch.utils.data.Subset using these indices and the dataset before passing it to the DataLoader.
Alternatively, you could also pass these indices to a SubsetRandomSampler and assign it to the DataLoader via its sampler argument.
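A minimal sketch of the index extraction step, in case it helps. It assumes the dataset exposes class_to_idx and targets, as ImageFolder does. The helper name class_indices and the folder names "horses" and "zebras" are made up for illustration, and the stand-in object only mimics those two attributes so the snippet runs without any image files.

```python
from types import SimpleNamespace


def class_indices(dataset, class_name):
    """Return the positions of all samples labelled `class_name`.

    Relies on two attributes that ImageFolder provides:
    `class_to_idx` (class name -> integer label) and
    `targets` (one integer label per sample, in dataset order).
    """
    label = dataset.class_to_idx[class_name]
    return [i for i, target in enumerate(dataset.targets) if target == label]


# Stand-in with the same two attributes (hypothetical class names and labels).
toy = SimpleNamespace(
    class_to_idx={"horses": 0, "zebras": 1},
    targets=[0, 0, 1, 0, 1],
)
horse_idx = class_indices(toy, "horses")
zebra_idx = class_indices(toy, "zebras")

# With a real ImageFolder (called gan_dataset here), the indices plug into
# either approach:
#   horse_dset = Subset(gan_dataset, class_indices(gan_dataset, "horses"))
#   horse_loader = DataLoader(horse_dset, batch_size=5, shuffle=True)
# or
#   sampler = SubsetRandomSampler(class_indices(gan_dataset, "horses"))
#   horse_loader = DataLoader(gan_dataset, batch_size=5, sampler=sampler)
```

Looking the indices up from the labels avoids hard-coding index ranges, so it should keep working if files are added or removed. Note that shuffle should be left unset when a sampler is passed, since DataLoader treats the two options as mutually exclusive.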
Thank you for the solution @ptrblck. Using torch.utils.data.Subset did the trick for me.
Here is how I managed to do it.
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# dataroot must be defined beforehand as the root directory of the dataset
gan_dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# The ImageFolder class has the following attributes, which can be used to find
# the range of indices for Subset:
print(gan_dataset.class_to_idx)
# gan_dataset.imgs is a list of (file_path, class_index) tuples for all items in the dataset
print(gan_dataset.imgs)

# These index ranges are specific to my dataset
horse_dset = Subset(gan_dataset, range(260, 1327))
zebra_dset = Subset(gan_dataset, range(1327, 2670))

horse_loader = DataLoader(horse_dset, batch_size=5, shuffle=True)
zebra_loader = DataLoader(zebra_dset, batch_size=5, shuffle=True)