I’m trying to extend the PyTorch MNIST dataset to create a custom dataset of labeled pairs, i.e. tuples of `(img1, img2, new_label)`. I figured the `__getitem__` function was the most appropriate entry point for this task. However, I’m getting weird results regarding the shape of the output. If I create a barebones extension (shown below, first snippet) I get the expected output shape of `torch.Size([10, 1, 28, 28])` for a batch size of 10. However, when I try to access the data by index in the `__getitem__` function (second snippet below), it drops the first dimension of the image, so the output shape is `torch.Size([10, 28, 28])`. More importantly, when I try to access the “third” dimension of the image in the `__getitem__` function I get an `IndexError` because it does not exist. What am I doing wrong?
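To rule out the DataLoader, I reproduced the indexing behaviour on a dummy tensor with the shape I believe `self.data` has internally (a `[N, 28, 28]` uint8 tensor with no channel dimension; the dummy tensor here is just a stand-in for the real MNIST data):

```python
import torch

# Stand-in for MNIST's internal `data` attribute, which (as far as I can
# tell from the shapes I'm seeing) is [N, 28, 28] -- no channel dimension.
data = torch.zeros(100, 28, 28, dtype=torch.uint8)

sample = data[0]
print(sample.shape)  # torch.Size([28, 28]) -- only two dims per sample

try:
    data[0, :, :, :]  # asking for a 4th dimension that isn't there
except IndexError as e:
    print(e)  # too many indices for a 3-dim tensor
```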
Snippet one, a barebones extension of the MNIST dataset (just passing the data through):

```python
class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        super(TestSubLoader, self).__init__(*args, **kwargs)

    def __len__(self):
        return len(self.data)
```
Which outputs:

```
__main__ sub_image: torch.Size([10, 1, 28, 28])
__main__ pri_image: torch.Size([10, 1, 28, 28])
```
Snippet two, adding in a `__getitem__` that simply returns the indexed data:

```python
class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        super(TestSubLoader, self).__init__(*args, **kwargs)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]  # <-- self.data[idx, :, :, :] = IndexError
```
Which outputs:

```
__main__ sub_image: torch.Size([10, 28, 28])
__main__ pri_image: torch.Size([10, 1, 28, 28])
```
Here is the code I use to run the test and produce the outputs above:

```python
mnist_set_subloaded = TestSubLoader(root=data_dir, train=True,
                                    download=True,
                                    transform=transforms.Compose([transforms.ToTensor()]))
mnist_set_primary = torchvision.datasets.MNIST(root=data_dir, train=True,
                                               download=True,
                                               transform=transforms.Compose([transforms.ToTensor()]))

subloaded_loader = torch.utils.data.DataLoader(mnist_set_subloaded,
                                               shuffle=True,
                                               num_workers=2,
                                               batch_size=10)
primary_loader = torch.utils.data.DataLoader(mnist_set_primary,
                                             shuffle=True,
                                             num_workers=2,
                                             batch_size=10)

for i, ((sub_image, sub_target), (pri_image, pri_target)) \
        in enumerate(zip(subloaded_loader, primary_loader)):
    if i % 10 == 0:
        print(f'__main__ sub_image: {sub_image.shape}')
        print(f'__main__ pri_image: {pri_image.shape}')
    break
```
First, is this a good approach to creating image pairs? (Obviously, the pairing itself is not implemented above.) Second, regarding the main topic: what am I missing? Do I have to manually reshape `self.data` to maintain the expected shape?
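For context, here is roughly the structure I have in mind for the pairs (an untested sketch: `PairDataset`, the random pairing rule, and the same/different labeling scheme are all placeholders, and the dummy tensors stand in for the real MNIST data):

```python
import torch
from torch.utils.data import Dataset

class PairDataset(Dataset):
    '''Hypothetical sketch: `data`/`targets` stand in for the real
    MNIST attributes; the pairing rule is just a placeholder.'''
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img1 = self.data[idx]
        # Pick a second image at random (placeholder pairing rule).
        jdx = torch.randint(len(self.data), (1,)).item()
        img2 = self.data[jdx]
        # new_label: 1 if the two digits match, else 0 (one possible scheme).
        new_label = int(self.targets[idx] == self.targets[jdx])
        return img1, img2, new_label

# Quick shape check with dummy tensors in place of MNIST.
ds = PairDataset(torch.zeros(100, 1, 28, 28), torch.zeros(100, dtype=torch.long))
img1, img2, lbl = ds[0]
print(img1.shape, img2.shape, lbl)  # torch.Size([1, 28, 28]) torch.Size([1, 28, 28]) 1
```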
Thanks for your help.