Hi,
I am looking to build a GAN model using PyTorch, which requires me to load images of different classes separately. I have used torchvision.datasets.ImageFolder to load my dataset. I have images of horses and zebras. How can I load the images from only one class of the data? That is, I want to be able to tell my data loader whether I want images of horses or zebras.
You could extract all indices for the desired class (e.g. horses) and create a torch.utils.data.Subset using these indices and the dataset before passing it to the DataLoader.
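A minimal sketch of the Subset approach might look like the following. It assumes a small TensorDataset of random tensors as a stand-in for the real ImageFolder, and the helper indices_for_class and the class names are made up for illustration. With an actual ImageFolder, the labels would come from dataset.targets and the name-to-index mapping from dataset.class_to_idx.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for an ImageFolder (assumption): 10 fake "images" with labels
# 0 (horses) and 1 (zebras). With a real ImageFolder, use dataset.targets
# and dataset.class_to_idx instead of these hand-written values.
images = torch.randn(10, 3, 8, 8)
targets = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
class_to_idx = {"horses": 0, "zebras": 1}
dataset = TensorDataset(images, torch.tensor(targets))


def indices_for_class(targets, class_index):
    """Hypothetical helper: positions of all samples labeled class_index."""
    return [i for i, t in enumerate(targets) if t == class_index]


horse_indices = indices_for_class(targets, class_to_idx["horses"])
horse_dset = Subset(dataset, horse_indices)
horse_loader = DataLoader(horse_dset, batch_size=2, shuffle=True)

for batch_images, batch_labels in horse_loader:
    # Every batch should contain only samples of the requested class.
    assert (batch_labels == class_to_idx["horses"]).all()
```

Computing the indices from the labels avoids hard-coding index ranges, so the same code keeps working if files are added to or removed from a class folder.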
Alternatively, you could pass these indices to a SubsetRandomSampler and assign it to the DataLoader via its sampler argument.
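The sampler variant could be sketched roughly as below, using torch.utils.data.SubsetRandomSampler (the name of the sampler class in PyTorch). The TensorDataset of random tensors and the label values are assumptions standing in for the real image dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler

# Stand-in dataset (assumption): label 0 = horses, label 1 = zebras.
images = torch.randn(10, 3, 8, 8)
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
dataset = TensorDataset(images, labels)

zebra_class = 1
# Positions of all samples whose label equals the desired class.
zebra_indices = torch.nonzero(labels == zebra_class).flatten().tolist()

# The sampler draws only from zebra_indices, in random order.
# Do not also pass shuffle=True: DataLoader rejects sampler + shuffle.
zebra_loader = DataLoader(
    dataset,
    batch_size=3,
    sampler=SubsetRandomSampler(zebra_indices),
)

# Collect the labels seen in one pass over the loader.
seen_labels = torch.cat([batch_labels for _, batch_labels in zebra_loader])
```

Compared with Subset, this keeps a single dataset object and moves the class selection into the loader, which can be convenient when several loaders share one dataset.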
Hi,
Thank you for the solution, @ptrblck. Using torch.utils.data.Subset did the trick for me.
Here is how I managed to do it.
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# dataroot is the path to the dataset root folder (one subfolder per class), defined elsewhere
gan_dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# the ImageFolder class has the following attributes, which can be used to find
# the range of indices for Subset:
# class_to_idx maps each class (folder) name to its class index
print(gan_dataset.class_to_idx)
# gan_dataset.imgs is a list of tuples of (file_path, class_index) for all items in the dataset
print(gan_dataset.imgs)

# these index ranges are specific to my dataset; I read them off the output printed above
horse_dset = Subset(gan_dataset, range(260, 1327))
zebra_dset = Subset(gan_dataset, range(1327, 2670))

horse_loader = DataLoader(horse_dset, batch_size=5, shuffle=True)
zebra_loader = DataLoader(zebra_dset, batch_size=5, shuffle=True)