How to get MNIST data from torchvision with three channels for a pretrained model like VGG?

I think many people have run into this, and it shouldn't be difficult, but when I try to use the MNIST data from torchvision with a pretrained model from torchvision.models, such as VGG, I get the error:

Given groups=1, weight of size [64, 3, 3, 3], expected input[64, 1, 28, 28] to have 3 channels, but got 1 channels instead

It seems the model requires 3-channel input, but the MNIST data from torchvision has only 1 channel.

One (admittedly ugly) method is to add transforms.Grayscale(3) to the transform you pass to the MNIST dataset, which copies the single channel three times.

Another way would be to add an nn.Conv2d layer at the beginning of the pretrained model that maps the single input channel to 3 output channels. However, you would need to train this layer.

Thanks @ptrblck. I'm not sure how other frameworks like TensorFlow handle this, but it looks like an inconsistency in the package to me.

Hi, can you please explain how to do this for GoogLeNet?

You could either replace the first layer of GoogLeNet (model.conv1) with, e.g., an nn.Sequential container holding the new conv layer followed by the pretrained one:

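import torch
import torch.nn as nn
from torchvision import models

# random init here; use models.googlenet(weights=...) for the pretrained model,
# and then also pass transform_input=False, since that input transform expects 3 channels
model = models.googlenet()
model.conv1 = nn.Sequential(
    nn.Conv2d(1, 3, 1),  # trainable 1x1 conv: 1 channel -> 3 channels
    model.conv1)

x = torch.randn(1, 1, 224, 224)
out = model(x)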

or just use an extra layer before the model:

conv = nn.Conv2d(1, 3, 1)  # trainable 1x1 conv: 1 channel -> 3 channels
model = models.googlenet()

x = torch.randn(1, 1, 224, 224)
out = conv(x)     # shape [1, 3, 224, 224]
out = model(out)