Looks pretty good as a starter.
If I download the .gz MNIST test images file and the SVHN test images, extract them, and put them in the data/mnist and data/svhn/test folders respectively, it should create MyDataset as the new dataset with 100 images from MNIST followed by 10 images from SVHN, right?
I switched CIFAR for SVHN, and resized the SVHN images to 32 as well (I'll probably use a net that was trained on CIFAR). Are any other changes required? (Code below.)
Currently it appears that it can't find the dataset:
RuntimeError: Dataset not found. You can use download=True to download it
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
    """Random MNIST test digits followed by random SVHN test digits.

    Indices ``0 .. mnist_len - 1`` return MNIST samples; the remaining
    ``svhn_len`` indices return SVHN samples. Only images are returned,
    never labels.

    Args:
        mnist_transform: Optional callable applied to each MNIST sample.
        svhn_transform: Optional callable applied to each SVHN sample.
        mnist_root: Root directory handed to ``datasets.MNIST``.
        svhn_root: Root directory handed to ``datasets.SVHN``.
        mnist_len: Maximum number of MNIST images to sample.
        svhn_len: Maximum number of SVHN images to sample.
        download: Forwarded to both torchvision datasets. When the files
            are not laid out the way torchvision expects, the loaders
            raise ``RuntimeError`` unless this is ``True``.

    NOTE(review): torchvision looks for its own folder layout below each
    root (e.g. ``test_32x32.mat`` directly under ``svhn_root``; the MNIST
    layout differs between torchvision versions) -- manually extracted
    archives may not match; confirm against the installed version or pass
    ``download=True``.
    """

    def __init__(
        self,
        mnist_transform=None,
        svhn_transform=None,
        mnist_root='./data/mnist',
        svhn_root='./data/svhn/test',
        mnist_len=100,
        svhn_len=10,
        download=False,
    ):
        super().__init__()

        # Both loaders default to the *training* split; the test split has
        # to be requested explicitly.
        mnist = datasets.MNIST(root=mnist_root, train=False, download=download)
        svhn = datasets.SVHN(root=svhn_root, split='test', download=download)

        mnist_idx = torch.randperm(len(mnist.data))[:mnist_len]
        self.mnist_data = mnist.data[mnist_idx]

        # torchvision keeps SVHN as a channel-first numpy array
        # (N, C, H, W): index it with a numpy index array and move the
        # channels last so each sample is H x W x C, the layout
        # ToPILImage expects for ndarrays.
        svhn_idx = torch.randperm(len(svhn.data))[:svhn_len].numpy()
        self.svhn_data = svhn.data[svhn_idx].transpose(0, 2, 3, 1)

        # Derive the lengths from what was actually sampled so __len__
        # stays correct even if a dataset is smaller than requested.
        self.mnist_len = len(self.mnist_data)
        self.svhn_len = len(self.svhn_data)

        self.mnist_transform = mnist_transform
        self.svhn_transform = svhn_transform

    def __getitem__(self, index):
        """Return the (optionally transformed) image at ``index``.

        Negative indices count from the end.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError('index {} is out of range'.format(index))

        if index < self.mnist_len:
            x = self.mnist_data[index]
            if self.mnist_transform is not None:
                x = self.mnist_transform(x)
            print('Returning MNIST sample at index {}'.format(index))
            return x

        # Past the MNIST range: shift into SVHN-local coordinates.
        index = index - self.mnist_len
        x = self.svhn_data[index]
        if self.svhn_transform is not None:
            x = self.svhn_transform(x)
        print('Returning SVHN data at index {}'.format(index))
        return x

    def __len__(self):
        """Return the combined number of MNIST and SVHN samples."""
        return self.mnist_len + self.svhn_len
# MNIST pipeline: convert to PIL, upscale to 32x32, replicate the single
# grey channel three times so the result matches a 3-channel network
# input, then scale to [-1, 1].
mnist_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# SVHN pipeline: ToPILImage has to run *before* Resize, because Resize
# operates on PIL images (or tensors), not on raw arrays.
# NOTE(review): ToPILImage reads ndarrays as H x W x C -- confirm the
# samples MyDataset hands to this transform use that layout.
svhn_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Build the combined dataset with one transform per source.
dataset = MyDataset(
    mnist_transform=mnist_transform,
    svhn_transform=svhn_transform,
)