Creating a custom dataloader

I am trying to create a dataloader that outputs pairs of MNIST digits (for a multimodal VAE) in the form (0,1), (2,3), (4,5), (6,7), (8,9). I have written a dataloader that outputs pairs of identical digits: (0,0), (1,1), …, (9,9). However, I cannot figure out how to pair each digit with the digit one greater, with the constraint that the digit in the first modality is even. Here is the code I have so far:

class JointDataset(torch.utils.data.Dataset):
    """Pair each sample of one MNIST ``.pt`` file with a same-label sample of another.

    Both files are loaded with ``torch.load`` and must unpack into
    ``(data, targets)``. For every item of the first file, a random item
    with the *same* label is drawn from the second file. The resulting pairs
    are therefore (0, 0), (1, 1), ..., (9, 9).

    Args:
        mnist_pt_path_1: Path of the ``.pt`` file used for modality 1.
        mnist_pt_path_2: Path of the ``.pt`` file used for modality 2.
    """

    def __init__(self, mnist_pt_path_1, mnist_pt_path_2):
        self.mnist_pt_path_1 = mnist_pt_path_1
        self.mnist_pt_path_2 = mnist_pt_path_2
        # Load the (data, targets) tuple for each modality.
        self.mnist_data_1, self.mnist_targets_1 = torch.load(self.mnist_pt_path_1)
        self.mnist_data_2, self.mnist_targets_2 = torch.load(self.mnist_pt_path_2)
        self.mnist_target_idx_mapping = self.process_mnist_labels()

    def process_mnist_labels(self):
        """Map each digit 0-9 to the modality-2 indices that carry that label.

        Returns:
            dict mapping every digit 0..9 to a (possibly empty) list of indices
            into ``self.mnist_targets_2``. A label outside 0..9 raises KeyError.
        """
        numbers_dict = {digit: [] for digit in range(10)}
        # Convert the target tensor to Python ints once instead of per element.
        for i, mnist_target in enumerate(self.mnist_targets_2.tolist()):
            numbers_dict[mnist_target].append(i)
        return numbers_dict

    def __len__(self):
        return len(self.mnist_data_1)

    def __getitem__(self, index: int):
        """Return ``(img_1 / 255, img_2 / 255, target_1, target_2)`` for *index*.

        ``img_1``/``target_1`` come from the first file at *index*. ``img_2``
        is drawn at random among second-file samples whose label equals
        ``target_1``, so ``target_2 == target_1``. The division by 255
        presumably rescales uint8 pixels to [0, 1] (dtype not visible here).

        Raises:
            IndexError: if *index* is out of range for the first file.
            ValueError: if the second file has no sample with ``target_1``.
        """
        # Index the data first so an out-of-range index raises IndexError
        # before anything else (this is what ends ``for x in dataset``).
        mnist_img_1 = self.mnist_data_1[index]
        mnist_target_1 = int(self.mnist_targets_1[index])

        indices_list = self.mnist_target_idx_mapping[mnist_target_1]
        if not indices_list:
            # random.choice([]) raises IndexError, which the legacy iteration
            # protocol would treat as "end of data" and silently truncate the
            # loop; fail loudly instead.
            raise ValueError(
                f"no sample with label {mnist_target_1} in the second dataset"
            )
        # Randomly pick a partner among the matching indices.
        idx = random.choice(indices_list)

        mnist_img_2 = self.mnist_data_2[idx]
        mnist_target_2 = int(self.mnist_targets_2[idx])

        return mnist_img_1 / 255, mnist_img_2 / 255, mnist_target_1, mnist_target_2

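A modulo operation in the key used to fill numbers_dict should work: every second-modality sample with label t is stored under key (t-1) % 10, so a first-modality label n is paired with label (n+1) % 10. Here is an example using random data: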

class JointDataset(torch.utils.data.Dataset):
    """Toy dataset pairing digit n (modality 1) with digit (n + 1) % 10 (modality 2).

    Labels are random integers in [0, 10). The "data" simply encodes the label
    (label + 0.1 for modality 1, label + 0.2 for modality 2) so that pairings
    are easy to inspect. Note that 9 is paired with 0, and odd digits also
    appear in modality 1.
    """

    def __init__(self):
        # Create random dataset and encode data via target
        self.mnist_targets_1 = torch.randint(0, 10, (100,))
        self.mnist_data_1 = self.mnist_targets_1 + 0.1

        self.mnist_targets_2 = torch.randint(0, 10, (100,))
        self.mnist_data_2 = self.mnist_targets_2 + 0.2
        self.mnist_target_idx_mapping = self.process_mnist_labels()

    def process_mnist_labels(self):
        """Map each digit n (0-9) to the modality-2 indices labelled (n + 1) % 10.

        A modality-2 sample with label t is stored under key (t - 1) % 10, so
        looking up a modality-1 label n yields candidates labelled n + 1
        (wrapping 9 -> 0). A label outside 0..9 raises KeyError.
        """
        numbers_dict = {digit: [] for digit in range(10)}
        for i, mnist_target in enumerate(self.mnist_targets_2.tolist()):
            numbers_dict[(mnist_target - 1) % 10].append(i)
        return numbers_dict

    def __len__(self):
        return len(self.mnist_data_1)

    def __getitem__(self, index: int):
        """Return ``(img_1, img_2, target_1, target_2)`` for *index*.

        ``img_2`` is drawn at random among modality-2 samples whose label is
        ``(target_1 + 1) % 10``.

        Raises:
            IndexError: if *index* is out of range for modality 1.
            ValueError: if no modality-2 sample carries the partner label.
        """
        mnist_img_1 = self.mnist_data_1[index]
        mnist_target_1 = int(self.mnist_targets_1[index])

        indices_list = self.mnist_target_idx_mapping[mnist_target_1]
        if not indices_list:
            # Avoid random.choice([])'s IndexError, which would silently end a
            # ``for x in dataset`` loop instead of reporting the problem.
            raise ValueError(
                f"no partner sample for label {mnist_target_1} in modality 2"
            )
        # Randomly pick an index from the indices list
        idx = random.choice(indices_list)

        mnist_img_2 = self.mnist_data_2[idx]
        mnist_target_2 = int(self.mnist_targets_2[idx])

        return mnist_img_1, mnist_img_2, mnist_target_1, mnist_target_2
    
dataset = JointDataset()
# Iterate explicitly over the indices reported by __len__. A bare
# ``for data in dataset`` falls back to the legacy __getitem__ protocol,
# which stops at the first IndexError raised anywhere inside __getitem__
# (e.g. random.choice on an empty candidate list) and would silently
# truncate the output instead of surfacing the error.
for index in range(len(dataset)):
    print(dataset[index])

Output:

(tensor(9.1000), tensor(0.2000), 9, 0)
(tensor(9.1000), tensor(0.2000), 9, 0)
(tensor(0.1000), tensor(1.2000), 0, 1)
(tensor(6.1000), tensor(7.2000), 6, 7)
(tensor(6.1000), tensor(7.2000), 6, 7)
(tensor(5.1000), tensor(6.2000), 5, 6)
(tensor(3.1000), tensor(4.2000), 3, 4)
(tensor(3.1000), tensor(4.2000), 3, 4)
(tensor(0.1000), tensor(1.2000), 0, 1)
(tensor(1.1000), tensor(2.2000), 1, 2)
...

Thank you for your response, but I need only even digits in the first modality and only odd digits in the second. Your code yields pairs (n, n+1), but I would like (2n, 2n+1). For example, (0,1) is allowed but (1,2) is not.

Sorry, I missed this requirement. So you would like to disallow pairs whose first digit is odd, such as (1, 2), and remove them? If so, could you check whether this works:

class JointDataset(torch.utils.data.Dataset):
    """Toy dataset pairing an even digit 2n (modality 1) with 2n + 1 (modality 2).

    Labels are random integers in [0, 10). The "data" simply encodes the label
    (label + 0.1 for modality 1, label + 0.2 for modality 2) so that pairings
    are easy to inspect. Modality 1 keeps only its even-labelled samples.
    Valid pairs are therefore (0, 1), (2, 3), (4, 5), (6, 7), (8, 9).
    """

    def __init__(self):
        # Create random dataset and encode data via target
        targets_1 = torch.randint(0, 10, (100,))
        data_1 = targets_1 + 0.1
        # Keep only even labels in modality 1. Apply the SAME mask to data and
        # targets so they stay aligned; with real image tensors, filtering only
        # the targets would mismatch images and labels.
        keep_even = targets_1 % 2 == 0
        self.mnist_targets_1 = targets_1[keep_even]
        self.mnist_data_1 = data_1[keep_even]

        self.mnist_targets_2 = torch.randint(0, 10, (100,))
        self.mnist_data_2 = self.mnist_targets_2 + 0.2
        self.mnist_target_idx_mapping = self.process_mnist_labels()

    def process_mnist_labels(self):
        """Map each even digit to the modality-2 indices labelled one higher.

        Returns:
            dict with keys 0, 2, 4, 6, 8. Odd modality-2 labels t are stored
            under t - 1 (1->0, ..., 9->8); even modality-2 samples are never
            used.
        """
        numbers_dict = {digit: [] for digit in (0, 2, 4, 6, 8)}
        for i, mnist_target in enumerate(self.mnist_targets_2.tolist()):
            if mnist_target % 2 != 0:
                numbers_dict[(mnist_target - 1) % 10].append(i)
        return numbers_dict

    def __len__(self):
        return len(self.mnist_data_1)

    def __getitem__(self, index: int):
        """Return ``(img_1, img_2, target_1, target_2)`` for *index*.

        ``target_1`` is even. ``img_2`` is drawn at random among modality-2
        samples labelled ``target_1 + 1``.

        Raises:
            IndexError: if *index* is out of range for modality 1.
            ValueError: if no modality-2 sample carries label ``target_1 + 1``.
        """
        mnist_img_1 = self.mnist_data_1[index]
        mnist_target_1 = int(self.mnist_targets_1[index])

        indices_list = self.mnist_target_idx_mapping[mnist_target_1]
        if not indices_list:
            # Avoid random.choice([])'s IndexError, which would silently end a
            # ``for x in dataset`` loop instead of reporting the problem.
            raise ValueError(
                f"no modality-2 sample with label {mnist_target_1 + 1}"
            )
        # Randomly pick an index from the indices list
        idx = random.choice(indices_list)

        mnist_img_2 = self.mnist_data_2[idx]
        mnist_target_2 = int(self.mnist_targets_2[idx])

        return mnist_img_1, mnist_img_2, mnist_target_1, mnist_target_2
    
dataset = JointDataset()
# Iterate explicitly over the indices reported by __len__. A bare
# ``for data in dataset`` falls back to the legacy __getitem__ protocol,
# which stops at the first IndexError raised anywhere inside __getitem__
# (e.g. random.choice on an empty candidate list) and would silently
# truncate the output instead of surfacing the error.
for index in range(len(dataset)):
    print(dataset[index])

Output:

(tensor(8.1000), tensor(9.2000), 8, 9)
(tensor(4.1000), tensor(5.2000), 4, 5)
(tensor(8.1000), tensor(9.2000), 8, 9)
(tensor(8.1000), tensor(9.2000), 8, 9)
(tensor(0.1000), tensor(1.2000), 0, 1)
(tensor(8.1000), tensor(9.2000), 8, 9)
...
1 Like

Thank you for your response. Unfortunately it doesn't work: I get `KeyError: 5`. The error occurs in the `__getitem__` method on the line:

indices_list = self.mnist_target_idx_mapping[(mnist_target_1)]

Did you add this line as well?

        # remove invalid targets
        self.mnist_targets_1 = self.mnist_targets_1[self.mnist_targets_1%2==0]

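This line removes the first-modality samples with odd labels, so every remaining label has an entry in `mnist_target_idx_mapping` and no KeyError should be raised. In this toy example `mnist_data_1` is created *after* the filtering, so data and targets stay aligned. With the real MNIST tensors from your first snippet, you must apply the same boolean mask to both tensors. For example, compute `mask = self.mnist_targets_1 % 2 == 0`, then set `self.mnist_data_1 = self.mnist_data_1[mask]` and `self.mnist_targets_1 = self.mnist_targets_1[mask]`. Otherwise images and labels will no longer match.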

1 Like

Thank you very much. It works now.