Fine-tuning SqueezeNet on the MNIST dataset


(Trần Công) #1

I want to fine-tune the SqueezeNet model from torchvision/models/squeezenet.py on the MNIST dataset.
What should I do? Thanks!


#2

You could use the MNIST example as your template and just change a few lines of code.

Since MNIST consists of 28x28 grayscale images, you would need to repeat the single channel three times to simulate an RGB image, and resize each image to 224x224 to fit SqueezeNet.
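To see what these two preprocessing steps do to a single sample, here is a minimal sketch on a random 28x28 tensor standing in for one MNIST digit (no real data is loaded). It uses `torch.nn.functional.interpolate` for the resize as a tensor-level equivalent of `transforms.Resize`, and the helper name `to_squeezenet_input` is hypothetical.

```python
import torch
import torch.nn.functional as F

def to_squeezenet_input(img: torch.Tensor) -> torch.Tensor:
    """Turn a [1, 28, 28] grayscale tensor into a [3, 224, 224] pseudo-RGB tensor."""
    # interpolate expects a batch dimension: [N, C, H, W]
    resized = F.interpolate(img.unsqueeze(0), size=(224, 224),
                            mode="bilinear", align_corners=False).squeeze(0)
    # expand repeats the single channel 3 times as a view, without copying memory
    return resized.expand(3, -1, -1)

digit = torch.rand(1, 28, 28)  # stand-in for one MNIST image after ToTensor()
x = to_squeezenet_input(digit)
print(x.shape)                  # torch.Size([3, 224, 224])
print(torch.equal(x[0], x[2]))  # True: all three channels are identical
```

Because `expand` only creates a view, the three "RGB" channels share the same memory, which is cheap and sufficient since the network only reads from them.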

Here is a small example of the necessary changes:

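```python
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# weights=None needs torchvision >= 0.13 (older versions: pretrained=False).
# For real fine-tuning, pass models.SqueezeNet1_1_Weights.IMAGENET1K_V1 instead.
model = models.squeezenet1_1(weights=None)

# Replace the final 1x1 conv so the model predicts 10 MNIST classes instead of 1000
model.classifier[1] = nn.Conv2d(512, 10, kernel_size=1)
model.num_classes = 10

dataset = datasets.MNIST(
    root='PATH',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize(224),                            # 28x28 -> 224x224
        transforms.ToTensor(),                             # [1, 224, 224] in [0, 1]
        transforms.Lambda(lambda x: x.expand(3, -1, -1))   # repeat channel -> [3, 224, 224]
    ])
)

loader = DataLoader(
    dataset,
    batch_size=10,
    shuffle=True,
    num_workers=2
)

data, target = next(iter(loader))
output = model(data)
print(output.shape)  # torch.Size([10, 10])
```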