nikhil6041
(Nikhil Kumar Ghanghor)
September 10, 2021, 3:50am
1
I came across this answer, https://stackoverflow.com/questions/63975130/how-to-get-only-specific-classes-from-pytorchs-fashionmnist-dataset, which samples images from a subset of classes using ImageFolder. I want to achieve the same thing with a custom Dataset class. Any suggestions on how I can do that?
I have attached my custom Dataset class for reference.
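import glob

import torch
from skimage import io  # assumed: io.imread is scikit-image's reader
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, root_dir, transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        # Collect every file exactly two levels below root_dir.
        self.img_list = sorted(glob.glob(root_dir + "/*/*"))

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_name = self.img_list[idx]
        image = io.imread(img_name)
        if self.transforms:
            # Apply the transforms and return the transformed image.
            image = self.transforms(image)
        return image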
ptrblck
September 10, 2021, 4:59am
2
You won’t be able to use the same approach, as your custom Dataset doesn’t return any targets, so you cannot use them to filter out the samples.
nikhil6041
(Nikhil Kumar Ghanghor)
September 10, 2021, 5:15am
3
@ptrblck What if I change my custom Dataset to return the labels as well?
ptrblck
September 10, 2021, 6:09am
4
If you are creating the targets in the __init__ method or are passing them to the Dataset, i.e. they are pre-calculated and not lazily computed, you could apply the same approach as in your posted link by filtering the .targets as well as the corresponding .img_list, or by using a Subset.
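The filtering step described above can be sketched roughly as follows. This is only an illustration of the bookkeeping, kept free of PyTorch so it runs on its own. It assumes the labels come from the parent folder name (the root_dir/<class_name>/<image> layout suggested by the glob pattern); the helpers build_targets and select_indices and the example paths are hypothetical and not part of the original Dataset.

```python
import os


def build_targets(img_list):
    """Derive an integer label for each path from its parent folder name.

    Assumption: images are stored as root_dir/<class_name>/<image>.
    """
    class_names = sorted({os.path.basename(os.path.dirname(p)) for p in img_list})
    class_to_idx = {name: i for i, name in enumerate(class_names)}
    targets = [class_to_idx[os.path.basename(os.path.dirname(p))] for p in img_list]
    return targets, class_to_idx


def select_indices(targets, keep_classes):
    """Return the positions of all samples whose label is in keep_classes."""
    keep = set(keep_classes)
    return [i for i, t in enumerate(targets) if t in keep]


# Hypothetical file list, standing in for sorted(glob.glob(root_dir + "/*/*")).
img_list = [
    "data/cat/1.png",
    "data/cat/2.png",
    "data/dog/1.png",
    "data/frog/1.png",
]

# Pre-compute the targets once, as one would in __init__.
targets, class_to_idx = build_targets(img_list)

# Keep only the "cat" and "frog" samples.
indices = select_indices(targets, [class_to_idx["cat"], class_to_idx["frog"]])

# Option 1: filter both lists in lockstep so they stay aligned.
filtered_imgs = [img_list[i] for i in indices]
filtered_targets = [targets[i] for i in indices]
```

For the second option, the same indices list could be passed to torch.utils.data.Subset(dataset, indices), which leaves the underlying dataset untouched. In either case, __getitem__ would also need to return the label, e.g. (image, self.targets[idx]), for the labels to reach the DataLoader.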