I am trying to subset particular class (in particular, samples from labels 0, 4, 8) samples from the MNIST-M Dataset (source). Since torchvision does not have a predefined library function to load the MNIST-M Dataset, I am using the following custom dataset class function:
class MNIST_M(torch.utils.data.Dataset):
def __init__(self, root, train, transform=None):
self.train = train
self.transform = transform
if train:
self.image_dir = os.path.join(root, 'mnist_m_train')
self.labels_file = os.path.join(root, "mnist_m_train_labels.txt")
else:
self.image_dir = os.path.join(root, 'mnist_m_test')
self.labels_file = os.path.join(root, "mnist_m_test_labels.txt")
with open(self.labels_file, "r") as fp:
content = fp.readlines()
self.mapping = list(map(lambda x: (x[0], int(x[1])), [c.strip().split() for c in content]))
def __len__(self):
return len(self.mapping)
def __getitem__(self, idx):
image, labels = self.mapping[idx]
image = os.path.join(self.image_dir, image)
image = self.transform(Image.open(image).convert('RGB'))
return image, labels
def _load_data(self):
data = read_image_file(self.image_dir)
targets = read_label_file(self.labels_file)
return data, targets
To create a subset of the dataset, I am using the following code which I have ported from the similar case of subsetting MNIST Dataset (subsetting MNIST reference):
mnist_train_ds_modded = datasets.MNIST(root_dir, download=True, train=True, transform=source_transform)
mnistm_train_ds_modded = MNIST_M(root=root_dir, train=True,
transform=transforms.Compose([
transforms.Scale(imageSize),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
mnistm_train_ds_modded.labels = torch.tensor(mnistm_train_ds_modded.labels)
mnistm_train_indexes_0 = 1*(mnistm_train_ds_modded.labels == 0).nonzero().flatten().tolist()
mnistm_train_indexes_4 = 1*(mnistm_train_ds_modded.labels == 4).nonzero().flatten().tolist()
mnistm_train_indexes_8 = 1*(mnistm_train_ds_modded.labels == 8).nonzero().flatten().tolist()
mnistm_train_modded_idx = mnistm_train_indexes_0 + mnistm_train_indexes_4 + mnistm_train_indexes_8
mnistm_train_ds_modded.labels = mnistm_train_ds_modded.labels[mnistm_train_modded_idx]
mnistm_train_ds_modded.data = mnistm_train_ds_modded.data[mnistm_train_modded_idx]
Clearly, the MNIST-M dataset class written above does not have any attributes called data and labels. So the above code will not work as it does in the case of MNIST.
I went through the source code of MNIST to define the class attributes data and labels but I am unable to do the same for png files (MNIST-M has png files). Kindly help me define the class attributes so that I can subset the dataset.