Extended Dataset __getitem__ dimensions

I’m trying to extend the PyTorch MNIST dataset to create a custom dataset of labeled pairs (i.e., a tuple of (img1, img2, new_label)). I was thinking that the __getitem__ method was the most appropriate entry point for this task. However, I’m getting odd results regarding the shape of the output. If I create a bare-bones extension (first snippet below), I get the expected output shape of torch.Size([10, 1, 28, 28]) for a batch size of 10. However, when I access the data by index in the __getitem__ method (second snippet below), it drops the channel dimension of the image, so the output shape becomes torch.Size([10, 28, 28]). More importantly, when I try to access the “third” dimension of the image in __getitem__ I get an error because it does not exist. What am I doing wrong?

Snippet one, a bare-bones extension of the MNIST dataset (just passing the data through):

class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        
        super(TestSubLoader, self).__init__(*args, **kwargs)
        
    def __len__(self):
        return len(self.data)
    

Which outputs:

__main__ sub_image: torch.Size([10, 1, 28, 28])
__main__ pri_image: torch.Size([10, 1, 28, 28])

Snippet two, adding a __getitem__ that simply returns the indexed data:

class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        
        super(TestSubLoader, self).__init__(*args, **kwargs)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx] # <-- self.data[idx, :, :, :] raises an IndexError

Which outputs:

__main__ sub_image: torch.Size([10, 28, 28])
__main__ pri_image: torch.Size([10, 1, 28, 28])

Here is the code I use to run the test to get the outputs above:

mnist_set_subloaded = TestSubLoader(root=data_dir, train=True,
                                    download=True,
                                    transform=transforms.Compose([transforms.ToTensor()]))

mnist_set_primary = torchvision.datasets.MNIST(root=data_dir, train=True,
                                               download=True,
                                               transform=transforms.Compose([transforms.ToTensor()]))
                                    

subloaded_loader = torch.utils.data.DataLoader(mnist_set_subloaded,
                                               shuffle=True,
                                               num_workers=2,
                                               batch_size=10)

primary_loader = torch.utils.data.DataLoader(mnist_set_primary,
                                               shuffle=True,
                                               num_workers=2,
                                               batch_size=10)

for i, ((sub_image,sub_target),(pri_image,pri_target)) \
        in enumerate(zip(subloaded_loader,primary_loader)):

    if i%10 == 0:
        print(f'__main__ sub_image: {sub_image.shape}') 
        print(f'__main__ pri_image: {pri_image.shape}') 
        break

First, is this a good approach for creating image pairs (obviously not implemented above…)? Second, regarding the main topic, what am I missing? Do I have to manually reshape self.data to maintain the structure?

Thanks for your help.

This reference helped me head in the right direction. It seems I need to convert the raw data to a PIL image and then make sure to apply the transforms. So I have the code below. Is this accurate?

class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        
        super(TestSubLoader, self).__init__(*args, **kwargs)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        data, target = self.data[idx], self.targets[idx]
        # convert the raw uint8 tensor to a PIL image so the transforms produce the expected shape
        data = transforms.ToPILImage()(data)

        if self.transform is not None:
            data = self.transform(data)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return data, target

On to my other concern: is __getitem__ a good place to build image pairs? Any thoughts are welcome.

Yes, this would basically recreate the original MNIST __getitem__ method as seen here.

Alternatively, you could also call __getitem__ from the parent class directly via:

class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        super(TestSubLoader, self).__init__(*args, **kwargs)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return super().__getitem__(idx)

which would yield the same output shapes as the original MNIST dataset.

Yes, I think __getitem__ is the best place to add this logic to the data loading.
Let me know if it works for your use case or if you get stuck.
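To make the pairing idea concrete, here is a minimal sketch of one way to do it: a wrapper dataset whose __getitem__ draws a second sample and returns (img1, img2, new_label). The pairing rule (a uniformly random second index, with the label indicating whether the classes match) is just an illustrative assumption, and a synthetic stand-in dataset is used instead of MNIST so the sketch is self-contained:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairDataset(Dataset):
    '''
    Sketch: wraps an existing dataset and returns (img1, img2, new_label) tuples.
    '''
    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img1, target1 = self.base[idx]
        # illustrative pairing rule: draw a random second sample;
        # new_label is 1 if both samples share a class, else 0
        idx2 = torch.randint(len(self.base), (1,)).item()
        img2, target2 = self.base[idx2]
        return img1, img2, int(target1 == target2)

# synthetic stand-in for the transformed MNIST dataset (random images, 10 classes)
class FakeMNIST(Dataset):
    def __len__(self):
        return 100
    def __getitem__(self, idx):
        return torch.rand(1, 28, 28), idx % 10

loader = DataLoader(PairDataset(FakeMNIST()), batch_size=10)
img1, img2, lbl = next(iter(loader))
print(img1.shape, img2.shape, lbl.shape)
# torch.Size([10, 1, 28, 28]) torch.Size([10, 1, 28, 28]) torch.Size([10])
```

In a real setup you would pass your transformed MNIST dataset as base, and you would likely want a deterministic or class-balanced pairing rule instead of uniform random sampling.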