I am trying to extract only the male images from the PyTorch CelebA dataset. The documentation says to pass the `target_type` argument to `torchvision.datasets.CelebA`, but I am confused about how to specify this value to extract only the male images.
For reference, Male is attribute 20 (counting from 0), and in the .csv file 1 indicates male and -1 indicates female.
`target_type` selects which kind of target the `Dataset` returns, not which attribute: it accepts `"attr"`, `"identity"`, `"bbox"` or `"landmarks"` (or a list of these). With `target_type="attr"`, each sample's target is a tensor of the 40 binary attributes, and the gender attribute is at index 20. Note that torchvision maps the file's -1/1 values to 0/1, so "male" is 1.
Once this is done, you could collect the targets for all samples and filter them using your condition (e.g. only "male" targets). Then compute the indices of all "male" samples and pass them to a `Subset` so that it only yields these samples.
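A minimal sketch of the filtering step. It assumes torchvision's `CelebA` keeps all attributes in its `attr` tensor (an internal attribute rather than documented API, so check it exists in your version), which lets you filter without loading every image. The helper `male_indices` is hypothetical, and a synthetic `TensorDataset` stands in for the real data so the snippet runs without a download.

```python
import torch
from torch.utils.data import Subset, TensorDataset

MALE_IDX = 20  # 0-based position of "Male" among CelebA's 40 attributes


def male_indices(attr: torch.Tensor, male_idx: int = MALE_IDX) -> list:
    """Hypothetical helper: return the row indices whose Male attribute is 1.

    `attr` is assumed to be an (N, 40) tensor of 0/1 values, like the one
    torchvision's CelebA builds when it loads the attribute file.
    """
    mask = attr[:, male_idx] == 1
    return torch.nonzero(mask, as_tuple=True)[0].tolist()


# Synthetic stand-in for the real dataset: 4 tiny "images" plus attributes.
attr = torch.zeros(4, 40, dtype=torch.long)
attr[[0, 2], MALE_IDX] = 1  # mark samples 0 and 2 as "male"
images = torch.randn(4, 3, 8, 8)
dataset = TensorDataset(images, attr)

idx = male_indices(attr)
male_subset = Subset(dataset, idx)
print(idx, len(male_subset))  # [0, 2] 2
```

With the real data this would become `Subset(celeba, male_indices(celeba.attr))`, where `celeba = torchvision.datasets.CelebA(root, split="all", target_type="attr")`.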
Thanks! I seem to be running into a download error for the dataset. Is there any way to fix this?
This is the error: RuntimeError: The daily quota of the file img_align_celeba.zip is exceeded and it can't be downloaded. This is a limitation of Google Drive and can only be overcome by trying again later.
This issue is caused by the dataset's hosting on Google Drive rather than by your code, so you could try downloading it from another mirror.
OK! I downloaded the dataset to my local machine; however, I don't think I can still use `torchvision.datasets.CelebA` with the `target_type` argument. In this case, how would I filter the images by attribute?
Here’s how I solved this:
As suggested earlier, I downloaded the dataset from another mirror and created a custom `Dataset` class:
```python
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class MaleFacesDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        # One row per image; column 0 holds the image file name.
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        # Load the image and make sure it has three (RGB) channels.
        sample = Image.open(img_name).convert('RGB')
        if self.transform:
            sample = self.transform(sample)
        return sample
```
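Then I created an instance of this class:
```python
# transform=None returns PIL images; pass e.g. a torchvision transform
# ending in ToTensor() to get tensors that a DataLoader can batch.
male_dataset = MaleFacesDataset(csv_file='csv file dir',
                                root_dir='root dir',
                                transform=None)
```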
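As suggested earlier, I then extracted the indices of the male images as follows:
```python
import pandas as pd

landmarks_frame = pd.read_csv('attribute file dir')  # read the csv file with attributes
# Select the gender column by name (assumes a header row naming the attributes).
# Positional iloc[:, 20] would pick the wrong column if column 0 holds the file names.
male_images_list = landmarks_frame['Male']
male_images_list = male_images_list[male_images_list == 1]  # keep only male rows (attribute value 1)
# Row positions of the male images; these only line up with male_dataset
# if it was built from the same file, in the same row order.
male_images_index = male_images_list.index.tolist()
```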
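To see why selecting the column by name is safer than a positional index, here is a small self-contained check on a synthetic CSV. It assumes the same layout as above (a file-name column first, then attribute columns holding 1/-1); only two attribute columns are included, and the file names and `Smiling` column are illustrative.

```python
import io

import pandas as pd

# Synthetic stand-in for the attribute CSV: file-name column first,
# then attribute columns with 1 (present) / -1 (absent).
csv_text = """image_id,Smiling,Male
000001.jpg,1,-1
000002.jpg,-1,1
000003.jpg,-1,-1
000004.jpg,1,1
"""
attrs = pd.read_csv(io.StringIO(csv_text))

# Label-based selection is unaffected by the leading image_id column.
male_mask = attrs["Male"] == 1
# With the default RangeIndex, index labels equal row positions.
male_images_index = attrs[male_mask].index.tolist()
male_files = attrs.loc[male_mask, "image_id"].tolist()

print(male_images_index)  # [1, 3]
print(male_files)  # ['000002.jpg', '000004.jpg']
```

Here `Male` is the second attribute but sits at column position 2, because `image_id` occupies position 0; the same one-column shift applies to the full 40-attribute file.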
And finally, I created a subset of my original dataset:
```python
from torch.utils.data import Subset

male_subset = Subset(male_dataset, male_images_index)
```
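An alternative that avoids keeping two files' row orders in sync is to filter the attribute table inside the `Dataset` itself. `AttributeFilteredDataset` below is a hypothetical variant of `MaleFacesDataset`, assuming the CSV has an `image_id` column and named attribute columns with 1/-1 values; the demo writes four tiny synthetic PNGs to a temporary directory so it runs without the real data.

```python
import os
import tempfile

import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class AttributeFilteredDataset(Dataset):
    """Hypothetical variant of MaleFacesDataset that filters rows up front."""

    def __init__(self, csv_file, root_dir, attribute="Male", value=1, transform=None):
        frame = pd.read_csv(csv_file)
        # Keep only matching rows; reset_index makes labels 0..len-1 again.
        self.frame = frame[frame[attribute] == value].reset_index(drop=True)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.frame.loc[idx, "image_id"])
        sample = Image.open(img_path).convert("RGB")
        if self.transform:
            sample = self.transform(sample)
        return sample


# Self-contained demo: four tiny black images, two of them marked "male".
tmp = tempfile.mkdtemp()
rows = [("a.png", 1), ("b.png", -1), ("c.png", 1), ("d.png", -1)]
for name, _ in rows:
    Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)).save(os.path.join(tmp, name))
csv_path = os.path.join(tmp, "attrs.csv")
pd.DataFrame(rows, columns=["image_id", "Male"]).to_csv(csv_path, index=False)

male_only = AttributeFilteredDataset(csv_path, tmp)
print(len(male_only), male_only[0].size)  # 2 (8, 8)
```

Because the filtered frame is re-indexed from 0, this dataset can be passed straight to a `DataLoader` (with a transform such as `torchvision.transforms.ToTensor()` so batches can be collated), and no `Subset` is needed.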