Hi, sorry for yet another variation on the SVHN grayscale-and-resize problem. After following a lot of advice here on the forum, I have a solution that produces some output. However, when I plot a sample image, it just shows a distorted color image instead of the expected grayscale house number.
EDIT: I want an SVHN dataset with grayscale 28 × 28 images so that I can train my MNIST CNN routine on it. Is this a correct way to load the dataset?
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
class GetDataset(datasets.SVHN):
    """SVHN dataset whose ``transform`` receives each image as a NumPy array.

    ``__getitem__`` hands ``transform`` an (H, W, C) uint8 array, so the
    transform pipeline should begin with ``transforms.ToPILImage()`` (or
    otherwise accept channels-last arrays).
    """

    def __init__(self, root, split='train',
                 transform=None, target_transform=None, download=True):
        """
        Args:
            root (str): Directory where the SVHN files are stored/downloaded.
            split (str): Dataset split passed through to ``datasets.SVHN``.
            transform (callable, optional): Applied to each (H, W, C) image array.
            target_transform (callable, optional): Applied to each integer label.
            download (bool): Passed through to ``datasets.SVHN``; defaults to
                True here so the data is fetched on first use.
        """
        # Keyword arguments keep this call correct regardless of the parent
        # constructor's positional parameter order.
        super().__init__(
            root,
            split=split,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        # torchvision's SVHN keeps self.data channels-first, (N, 3, H, W), but
        # ToPILImage interprets a NumPy array as (H, W, C). Without this
        # transpose the pixels are scrambled into a distorted image (or, on
        # newer torchvision, rejected as having too many channels).
        img = self.data[index].transpose(1, 2, 0)
        target = int(self.labels[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target
# Preprocessing that turns one SVHN image into an MNIST-shaped input.
# Expects an (H, W, C) uint8 array: ToPILImage reads NumPy arrays as
# channels-last, so channels-first data must be transposed beforehand.
svhn_transform = transforms.Compose([
    transforms.ToPILImage(),                      # (H, W, C) array -> PIL image
    transforms.Resize((28, 28)),                  # MNIST resolution
    transforms.Grayscale(num_output_channels=1),  # RGB -> single channel
    transforms.ToTensor(),                        # -> float tensor (1, 28, 28) in [0, 1]
    # One mean/std entry per channel (there is exactly one channel now);
    # maps [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
    transforms.Normalize([0.5], [0.5]),
])
def main():
    """Load the SVHN training split with ``svhn_transform`` and plot one sample.

    Draws a single shuffled batch from a DataLoader and shows channel 0 of
    the first image in that batch using a grayscale colormap.
    """
    dataset = GetDataset(root='./data', transform=svhn_transform)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=128,
        shuffle=True,
        num_workers=1,
    )

    images, _labels = next(iter(loader))
    # With svhn_transform, each batch is (batch, 1, 28, 28); [0, 0] selects
    # the sole channel of the first sample as a 2-D image.
    first_image = images[0, 0]

    plt.imshow(first_image, cmap='gray')
    plt.show()
if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import. main()
    # uses DataLoader workers (num_workers=1), and on platforms that spawn
    # worker processes this guard keeps them from re-running the script.
    main()
Summary
This text will be hidden