ResNet18 output shape depending on number of in_channels

Hi everyone :slight_smile:

I am using ResNet18 for a deep learning project on CIFAR10. However, I want to pass grayscale versions of the CIFAR10 images to it. When I change the expected number of input channels and change the number of classes from 1000 to 10, I get output shapes that I don't understand. Here is my code:

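```python
import torch.nn as nn
import torchvision.models as PyTorchModels
from torchsummary import summary

r = PyTorchModels.resnet18()
num_in_channels = 1  # grayscale input
num_out_channels = r.conv1.out_channels  # 64
size_kernel = r.conv1.kernel_size  # (7, 7)
num_in_features = r.fc.in_features  # 512
classes = 10  # CIFAR10
# replace the stem conv to accept 1 channel (stride, padding and bias are left at their defaults)
r.conv1 = nn.Conv2d(in_channels=num_in_channels, out_channels=num_out_channels, kernel_size=size_kernel)
# replace the classifier head: 1000 -> 10 classes
r.fc = nn.Linear(in_features=num_in_features, out_features=classes)
model = r
summary(model.cuda(), (1, 32, 32))
```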

It produces the following output (only the first couple of layers are shown):

Why is the first output of shape [-1, 64, 26, 26]? Where does the 26 come from? I would expect it to be of shape [-1, 64, 16, 16]. More generally, how does the number of in_channels (e.g. RGB vs. grayscale) affect the output shape?

Any help is very much appreciated!

All the best

For the first question:

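Run print(model) and you should see that the first layer is (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False). You replaced the first layer without passing padding and stride, so nn.Conv2d used its defaults (stride=1, padding=0). A 7x7 kernel then shrinks the 32x32 input to 32 - 7 + 1 = 26. With stride=2 and padding=3 you get floor((32 + 2*3 - 7) / 2) + 1 = 16, which is the shape you expected.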
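The 26 follows from the output-size formula given in the nn.Conv2d documentation: H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1. A small pure-Python helper (the name conv2d_out_size is hypothetical, not part of PyTorch) makes the comparison concrete:

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output size along one spatial dimension of a 2D convolution,
    following the formula in the nn.Conv2d documentation."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# Replacement conv1 from the question: only kernel_size was copied,
# so the nn.Conv2d defaults stride=1 and padding=0 are used.
print(conv2d_out_size(32, kernel_size=7))  # 26

# Original ResNet18 conv1: kernel 7, stride 2, padding 3.
print(conv2d_out_size(32, kernel_size=7, stride=2, padding=3))  # 16
```

Note that in_channels does not appear in the formula at all: the spatial size depends only on the kernel size, stride, padding and dilation.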

For the second question:

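Since this is a CNN, the number of in_channels does not affect the output shape: it only changes the depth of the first layer's filters, and conv1 still produces 64 feature maps. The final output shape is determined by the number of classes.
For example, with ImageNet the output of ResNet18 is 1x1000; with CIFAR10 it is 1x10.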
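The same formula can be chained through the whole network. The sketch below is a hypothetical pure-Python helper that assumes torchvision's ResNet18 layout: a 7x7 stem conv, a 3x3 stride-2 max pool, four stages of basic blocks whose first 3x3 conv has stride 1, 2, 2, 2, adaptive average pooling to 1x1, and a final fc layer. It traces the feature-map sizes for a square input. The number of input channels never enters the calculation; the spatial sizes come from kernel size, stride and padding, and the final shape comes from num_classes.

```python
def conv_out(size, k, s=1, p=0):
    # nn.Conv2d / nn.MaxPool2d output-size formula with dilation = 1
    return (size + 2 * p - (k - 1) - 1) // s + 1


def resnet18_shapes(size, num_classes, conv1=(7, 2, 3)):
    """Trace (name, C, H, W) through a ResNet18-style network.

    Hypothetical helper; conv1 is (kernel, stride, padding) of the stem conv.
    """
    k, s, p = conv1
    shapes = []
    size = conv_out(size, k, s, p)  # stem conv: always 64 output channels
    shapes.append(("conv1", 64, size, size))
    size = conv_out(size, 3, 2, 1)  # max pool 3x3, stride 2, padding 1
    shapes.append(("maxpool", 64, size, size))
    stages = [("layer1", 64, 1), ("layer2", 128, 2),
              ("layer3", 256, 2), ("layer4", 512, 2)]
    for name, channels, stride in stages:
        size = conv_out(size, 3, stride, 1)  # first 3x3 conv of each stage sets the stride
        shapes.append((name, channels, size, size))
    shapes.append(("avgpool", 512, 1, 1))  # adaptive average pool to 1x1
    shapes.append(("fc", num_classes))  # flattened 512 features -> classes
    return shapes


for row in resnet18_shapes(32, 10):  # original stem
    print(row)
for row in resnet18_shapes(32, 10, conv1=(7, 1, 0)):  # stem from the question
    print(row)
```

With the stem from the question (conv1=(7, 1, 0)) the intermediate maps are larger (26, 13, 13, 7, 4, 2 instead of 16, 8, 8, 4, 2, 1), but the adaptive pooling still reduces them to 1x1, so the output is still [-1, 10].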


@Toby Thank you so much! That solved it. I totally forgot to add the padding, stride, and bias :see_no_evil:

Edit: a quick follow-up question: does PyTorch automatically adjust the depth of the filters?


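Next time, you should open a new question :wink:
To answer your question: no, it doesn't. The depth of each filter is fixed by the in_channels argument when the nn.Conv2d layer is created, so you have to set it yourself, as you did for conv1. You can take a look at the source code.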
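Since nn.Conv2d stores its weight as a tensor of shape (out_channels, in_channels // groups, kernel_height, kernel_width), the depth of each filter is decided when the layer is constructed, not inferred from the input. The toy class below (a hypothetical pure-Python NaiveConv2d, not PyTorch code; stride 1, no padding, no bias, one group) mimics that weight layout to show why a 3-channel conv1 cannot consume a 1-channel image. In PyTorch the equivalent mismatch raises a RuntimeError rather than the ValueError used here.

```python
import random


class NaiveConv2d:
    """Toy 2D convolution (stride 1, no padding, no bias, one group).

    Hypothetical illustration: like nn.Conv2d, the weight has shape
    (out_channels, in_channels, k, k), so the filter depth is fixed here.
    """

    def __init__(self, in_channels, out_channels, k):
        self.in_channels = in_channels
        self.k = k
        # weight[o][c][i][j]: one k x k slice per input channel, per filter
        self.weight = [[[[random.uniform(-1, 1) for _ in range(k)]
                         for _ in range(k)]
                        for _ in range(in_channels)]
                       for _ in range(out_channels)]

    def __call__(self, x):
        # x[c][row][col]: a single image with C channels
        if len(x) != self.in_channels:
            raise ValueError(
                f"expected {self.in_channels} input channels, got {len(x)}")
        h, w, k = len(x[0]), len(x[0][0]), self.k
        out = []
        for filt in self.weight:
            # each filter sums over ALL input channels and its k x k window
            fmap = [[sum(filt[c][i][j] * x[c][row + i][col + j]
                         for c in range(self.in_channels)
                         for i in range(k) for j in range(k))
                     for col in range(w - k + 1)]
                    for row in range(h - k + 1)]
            out.append(fmap)
        return out


def image(channels, size):
    # constant test image of shape (channels, size, size)
    return [[[1.0] * size for _ in range(size)] for _ in range(channels)]


rgb_conv = NaiveConv2d(in_channels=3, out_channels=64, k=7)
gray_conv = NaiveConv2d(in_channels=1, out_channels=64, k=7)

out = gray_conv(image(1, 32))
print(len(out), len(out[0]), len(out[0][0]))  # 64 26 26

try:
    rgb_conv(image(1, 32))  # grayscale input into a 3-channel layer
except ValueError as err:
    print(err)  # expected 3 input channels, got 1
```

Replacing conv1 with a layer built with in_channels=1, as done in the question, is therefore the way to change the filter depth.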