This reference pointed me in the right direction. It seems I need to convert the data to a PIL image and then make sure the transforms are applied. So I have the code below. Is this accurate?
```python
import torchvision
from torchvision import transforms


class TestSubLoader(torchvision.datasets.MNIST):
    '''
    Subloader to extend the MNIST dataset
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # self.data is a uint8 tensor of shape (N, 28, 28);
        # int() matches the plain-int label the stock MNIST class returns
        data, target = self.data[idx], int(self.targets[idx])
        # transform the data to PIL to get the right shape
        data = transforms.ToPILImage()(data)
        if self.transform is not None:
            data = self.transform(data)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return data, target
```
On to my other concern: is `__getitem__` a good place to collect image pairs? Any thoughts are welcome.
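Assuming "image pairs" means positive/negative pairs for something like a Siamese or contrastive setup, one common option is a separate wrapper dataset that builds the pair inside `__getitem__`. The class below is a hypothetical sketch: `PairDataset`, its arguments, and the 50/50 positive/negative split are illustrative choices, not part of torchvision. It wraps any dataset that returns `(image, label)`, such as `TestSubLoader`, and takes the labels separately (e.g. `ds.targets`) so that building the label index in `__init__` does not load every image.

```python
import random
from collections import defaultdict


class PairDataset:
    """Hypothetical wrapper returning (img_a, img_b, same) for pair-based training.

    base   -- any map-style dataset whose __getitem__ returns (image, label)
    labels -- the label of every sample in base, in order (e.g. ds.targets)
    same   -- 1 if both images share a label, else 0
    """

    def __init__(self, base, labels):
        self.base = base
        # Build label -> sample indices once, so __getitem__ never scans the data.
        self.by_label = defaultdict(list)
        for i, lab in enumerate(labels):
            self.by_label[int(lab)].append(i)
        if len(self.by_label) < 2:
            raise ValueError("need at least two classes to form negative pairs")

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img_a, label_a = self.base[idx]
        label_a = int(label_a)
        if random.random() < 0.5:
            # Positive pair: any sample with the same label (may be idx itself).
            j = random.choice(self.by_label[label_a])
        else:
            # Negative pair: pick a different class, then a sample from it.
            other = random.choice([lab for lab in self.by_label if lab != label_a])
            j = random.choice(self.by_label[other])
        img_b, label_b = self.base[j]
        return img_a, img_b, int(label_a == int(label_b))
```

Since `DataLoader` only needs `__len__` and `__getitem__`, something like `DataLoader(PairDataset(ds, ds.targets), batch_size=64, shuffle=True)` should work directly. The label index is built once, so each `__getitem__` call stays cheap, while the random partner is drawn per call, so each epoch sees different pairs. The sketch uses the module-level `random` functions rather than a private `random.Random` instance because `DataLoader` reseeds the module-level generator in each worker process; a private instance would be copied with identical state into every worker.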