Size mismatch error in CNN FC layer

Hello everyone!

Hoping someone can help explain how I can calculate the input dimensionality into the first fully connected layer of the following CNN architecture I’m tinkering with… see the question marks in the code:

class AlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 96, kernel_size=3, padding=1),
            nn.BatchNorm2d(96),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(96, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 1))
        
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(?????????, 2048),
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 1024),
            nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(1024, num_classes))

If you can, please show how you arrived at the input dimensionality. I have also tried using the formula (W - F + 2P)/S + 1, but I know it has to be applied recursively through each of the preceding convolutional and pooling layers, and I was unable to apply it properly…

Thanks so much,
John

This code adds the annotations:

class AlexNet(nn.Module):
    def __init__(self, num_classes=10):
        super(AlexNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=1), # (224 - 5 + 2)/1 + 1 = 222
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 3, stride = 2)) # (222 + 0 - 1 * (3-1)-1)/2 + 1 = floor(110.5) = 110
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1), # 110
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2)) # 55
        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 96, kernel_size=3, padding=1), # 55
            nn.BatchNorm2d(96),
            nn.ReLU())
        self.layer4 = nn.Sequential(
            nn.Conv2d(96, 64, kernel_size=3, padding=1), # 55
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.layer5 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1), # 55
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 1)) # (55 + 0 - 1 * (2-1)-1)/1 + 1 = 54
        
        self.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(1, 2048), # placeholder in_features; the print in forward() shows the real value is 32 * 54 * 54 = 93312
            nn.ReLU())
        self.fc1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(2048, 1024),
            nn.ReLU())
        self.fc2 = nn.Sequential(
            nn.Linear(1024, 10))
        
    def forward(self, x):
        print(x.shape) # torch.Size([2, 1, 224, 224])
        x = self.layer1(x)
        print(x.shape) # torch.Size([2, 32, 110, 110])
        x = self.layer2(x)
        print(x.shape) # torch.Size([2, 64, 55, 55])
        x = self.layer3(x)
        print(x.shape) # torch.Size([2, 96, 55, 55])
        x = self.layer4(x)
        print(x.shape) # torch.Size([2, 64, 55, 55])
        x = self.layer5(x)
        print(x.shape) # torch.Size([2, 32, 54, 54])
        x = x.view(x.size(0), -1)
        print(x.shape) # torch.Size([2, 93312])
        x = self.fc(x)
        print(x.shape)
        x = self.fc1(x)
        print(x.shape)
        x = self.fc2(x)
        print(x.shape)
        return x

model = AlexNet()
x = torch.randn(2, 1, 224, 224)
out= model(x)

using the formulas provided in the docs.
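For reference, the output-size formula from the Conv2d / MaxPool2d docs that the annotations use is:

out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)

which reduces to the (W - F + 2P)/S + 1 form you mentioned when dilation=1.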
Your model contains some trivial layers such as:

nn.Conv2d(32, 64, kernel_size=3, padding=1)

which will keep the spatial size unchanged, and

nn.MaxPool2d(kernel_size = 2, stride = 2)

which will reduce the spatial size by 2x.

For the other layers I have just applied the formulas and added the results as comments.
Also, an easy way to avoid calculating the input features is to define the in_features of the linear layer with any value and print the activation shape in the forward pass, as is also done in my code.
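If you want to double-check those two claims quickly, a minimal standalone snippet (just for illustration) prints the shapes directly:

import torch
import torch.nn as nn

x = torch.randn(2, 32, 110, 110)
print(nn.Conv2d(32, 64, kernel_size=3, padding=1)(x).shape)  # torch.Size([2, 64, 110, 110]) -> size kept
print(nn.MaxPool2d(kernel_size=2, stride=2)(x).shape)        # torch.Size([2, 32, 55, 55])   -> size halved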

Here is a simple function you can use, based on the docs, to calculate the output size from any convolution or pooling layer:

import math

def calc_conv_size(length, kernel_size, stride=1, padding=0, dilation=1):
    # output size per the Conv2d/MaxPool2d docs formula
    return math.floor((length + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

For example:

out1 = calc_conv_size(224, 5, stride=1, padding=1)  # Conv2d in layer1
print(out1)  # 222
out2 = calc_conv_size(out1, 3, stride=2)  # MaxPool2d in layer1
print(out2)  # 110
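And, just as a sketch, chaining the helper through every conv/pool layer of the model above recovers the flattened in_features for the first linear layer (the values match the annotations):

size = 224
size = calc_conv_size(size, 5, stride=1, padding=1)  # layer1 Conv2d   -> 222
size = calc_conv_size(size, 3, stride=2)             # layer1 MaxPool  -> 110
size = calc_conv_size(size, 3, stride=1, padding=1)  # layer2 Conv2d   -> 110
size = calc_conv_size(size, 2, stride=2)             # layer2 MaxPool  -> 55
size = calc_conv_size(size, 3, stride=1, padding=1)  # layer3 Conv2d   -> 55
size = calc_conv_size(size, 3, stride=1, padding=1)  # layer4 Conv2d   -> 55
size = calc_conv_size(size, 3, stride=1, padding=1)  # layer5 Conv2d   -> 55
size = calc_conv_size(size, 2, stride=1)             # layer5 MaxPool  -> 54
print(32 * size * size)  # 93312, matching the printed activation shape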

Ha, thanks for sharing as I just typed it in manually!

Also, @johnvalen1, if you don't want to calculate the input features at all, you could use the nn.Lazy* modules, which will set in_features based on the first input.
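A minimal sketch of that approach, assuming the flattened layer5 activation from above:

import torch
import torch.nn as nn

fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.LazyLinear(2048),  # in_features is inferred from the first forward pass
    nn.ReLU())

x = torch.randn(2, 32 * 54 * 54)  # flattened layer5 activation for a 224x224 input
print(fc(x).shape)                # torch.Size([2, 2048])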

Hello, thank you so much. That's a good helper function to have lying around and call up when I need to visualize the sizes. From what I understand, calc_conv_size gives you the output dimensionality of the current conv layer, which is passed as the input dimensionality to the subsequent layer, right?

Hello,
Thanks so much for that code explanation with comments and shown calculations. It really helped me to understand. However, I don’t understand one thing: how did you calculate the 224 for the value of W in the first convolutional layer?

Also, I tried to run your code, but the first linear layer cannot be (1, 2048); it needs to be (800, 2048) according to the error thrown:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x800 and 1x2048)

So from this I knew that the first FC layer should be nn.Linear(800, 2048), but I do wish to understand the math rather than relying on the error output.

Do you mean the shape of x before being passed to the first conv layer?
If so, then it’s defined by the input shape via:

x = torch.randn(2, 1, 224, 224)

before being passed to the model:

out= model(x)

I assume you have changed the input shape then?
If so, you could recalculate or print the output shapes based on my annotations for an input spatial size of 224x224 using your new input shape.
@J_Johnson 's util. function might come handy to do so.


@johnvalen1
The above function just tells you what size you can expect out of a layer with a given input size.

Unlike other layer types, conv and pooling layers can take nearly any input size. They act on an input to change both the output values and (in most cases) the size, whereas other layer types require a strict input size. So you just need to know or define what size comes out of your final conv/pooling layers (the CNN part) before those “distilled images” pass on to your other layers (e.g. Linear, GRU, Transformer, etc.).

If you're using mixed-size images, you may want to apply an nn.AdaptiveAvgPool2d layer to standardize the sizes. I prefer to place it somewhere in the middle of the CNN: if it's at the beginning, you can lose critical texture information; if at the end, you can lose macro information.

We can work our way backward from the first fully connected layer (fc). Let's say we want the spatial size out of the CNN to be 1.

Let’s make an inverse function to iteratively determine the size in the previous layers in order to get this size 1 out.

def inv_conv_size(out_length, kernel_size, stride=1, padding=0, dilation=1):
    # inverse of calc_conv_size: the input length needed to produce out_length
    return (out_length - 1) * stride + 1 + dilation * (kernel_size - 1) - 2 * padding

If you work backward through each layer using the above inverse function, you should be able to calculate an input size. Say you want to place the AdaptiveAvgPool after layer3. Running the function backward starting from 1, through layer5 and layer4, you should get 2 for the required size going into layer4, so you could define it as nn.AdaptiveAvgPool2d(2) if placing it there (see the sketch below).
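A small sketch of that backward pass, reusing the layer definitions above:

s = inv_conv_size(1, 2, stride=1)             # layer5 MaxPool2d -> 2
s = inv_conv_size(s, 3, stride=1, padding=1)  # layer5 Conv2d    -> 2
s = inv_conv_size(s, 3, stride=1, padding=1)  # layer4 Conv2d    -> 2
print(s)  # 2, so nn.AdaptiveAvgPool2d(2) after layer3 yields a 1x1 output after layer5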


Hi, I see, so you decided that the ConvNet should take in images of dimensions 224 by 224. Mine was taking in images of size 28x28, so I was confused about what the 224 represented. In my case W = 28, and I would then redo your annotations with that value.
Perhaps that’s why (1, 2048) didn’t work for me, when it came to the first FC layer. The error thrown wanted an input dimensionality of (800, 2048) to that FC layer. I would attribute this to the fact that my initial images are of size (28, 28, 1).

My mistake for failing to specify my image dimensions in the original post! Everything is clear now, though.

Hi, thank you for that! You have provided both a forward and backward method for me to count layer dimensions, which is invaluable.

I also didn't know that about AdaptiveAvgPooling (namely, that it is best placed towards the middle of the CNN).

Question for you: why did you define an inverse conv size function? Is there generally utility in iterating backward from a desired output size? I haven't yet seen any use case for this. Or was it done for illustration, to show that it is indeed possible?

Thank you!

Glad it helped!

Regarding the inverse function, it can be useful if you need to guarantee a particular size at some point in your CNN model, for example so that the input into the Linear layers is of a certain size.

Normally, though, someone developing a model will just use padding='same' so that all Conv layers keep the spatial size, and then add a MaxPool2d with kernel_size=2, stride=2 after each Conv layer or block of Conv layers. In that case the size gets cut in half at every MaxPool layer, which makes it easy to calculate without a function (see the sketch below). Additionally, most people handle all image resizing before the data enters the model, so the input data is uniform.

However, in some cases where textures play an important role, such as in certain classification problems, you may want the type of setup I described above with an AvgPooling layer.
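For example, a minimal sketch of that padding='same' + MaxPool pattern (hypothetical channel sizes), assuming a 224x224 input:

import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=3, padding='same'),   # size unchanged: 224
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),              # halved: 112
    nn.Conv2d(32, 64, kernel_size=3, padding='same'),   # size unchanged: 112
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2))               # halved: 56

print(block(torch.randn(2, 1, 224, 224)).shape)  # torch.Size([2, 64, 56, 56])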