I am trying to create a dataloader which outputs even and odd digits of MNIST (for multimodal VAE) in the form (0,1);(2,3);(4,5);(6,7);(8,9). I have written down a dataloader which outputs the same digits: (0,0);(1,1)…(9,9). But I am not able to make the digits even and odd separated by 1 with a constraint that the digit in the first modality is even. Here is the code I have so far:
class JointDataset(torch.utils.data.Dataset):
def __init__(self, mnist_pt_path_1, mnist_pt_path_2):
self.mnist_pt_path_1 = mnist_pt_path_1
self.mnist_pt_path_2 = mnist_pt_path_2
# Load the pt for MNIST
self.mnist_data_1, self.mnist_targets_1 = torch.load(self.mnist_pt_path_1)
# Load the pt for MNIST
self.mnist_data_2, self.mnist_targets_2 = torch.load(self.mnist_pt_path_2)
self.mnist_target_idx_mapping = self.process_mnist_labels()
def process_mnist_labels(self):
numbers_dict = {0: [], 1: [], 2: [], 3:[], 4:[], 5:[], 6:[], 7: [], 8:[], 9:[]}
for i in range(len(self.mnist_targets_2)):
mnist_target = self.mnist_targets_2[i].item()
numbers_dict[mnist_target].append(i)
return numbers_dict
def __len__(self):
return len(self.mnist_data_1)
def __getitem__(self, index: int):
"""
Args:
index (int): Index
Modality 1: even digits
Modality 2: odd digits
"""
mnist_img_1, mnist_target_1 = self.mnist_data_1[index], int(self.mnist_targets_1[index])
indices_list = self.mnist_target_idx_mapping[(mnist_target_1)]
# Randomly pick an index from the indices list
idx = random.choice(indices_list)
mnist_img_2 = self.mnist_data_2[idx]
mnist_target_2 = int(self.mnist_targets_2[idx])
return mnist_img_1/255, mnist_img_2/255, mnist_target_1, mnist_target_2