# Understanding how to compute shape of input to Linear() layer

Hi,

I’m trying to walk through by hand the size of input at each layer of my CNN network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(20, 32, 3)
        self.pool3 = nn.MaxPool2d(2)
        self.linear1 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.pool3(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.pool3(x)
        x = F.relu(x)
        print(x.shape)
        x = self.linear1(x)
        return F.softmax(x)
```

I’m using the MNIST dataset as input.

- Input size: 28x28
- After conv1: 26x26
- After pool1: 13x13
- After conv2: 11x11
- After pool2: 5x5
- After conv3: 3x3
- After pool3: 1x1

So the number of features at the input of the linear layer is 32x1x1 = 32.
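To double-check the arithmetic above, here is a minimal sketch (my own, not from the original post) that pushes a dummy MNIST-sized batch through the same layer stack and prints the shape after each conv + pool pair:

```python
import torch
import torch.nn as nn

# Same layer definitions as in the model above.
conv1 = nn.Conv2d(1, 10, 3)
conv2 = nn.Conv2d(10, 20, 3)
conv3 = nn.Conv2d(20, 32, 3)
pool = nn.MaxPool2d(2)

x = torch.randn(1, 1, 28, 28)  # a batch of one 28x28 grayscale image
x = pool(conv1(x))             # 28 -> 26 (conv) -> 13 (pool)
print(x.shape)                 # torch.Size([1, 10, 13, 13])
x = pool(conv2(x))             # 13 -> 11 -> 5
print(x.shape)                 # torch.Size([1, 20, 5, 5])
x = pool(conv3(x))             # 5 -> 3 -> 1
print(x.shape)                 # torch.Size([1, 32, 1, 1])
```

This confirms the hand computation: the activation entering the linear layer has 32 x 1 x 1 = 32 features per sample.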

But I get an error

```python
import torch
import torch.nn as nn

maxPoolLayer = nn.MaxPool2d(2)
conv2dLayer = nn.Conv2d(1, 10, 3)
mA = torch.randn(1, 1, 25, 25)  # Variable is deprecated; a plain tensor works
mC = conv2dLayer(mA)
print(mC.size())
mB = maxPoolLayer(mC)
print(mB.size())
```

Output:

```
torch.Size([1, 10, 23, 23])
torch.Size([1, 10, 11, 11])
```

So a 3x3 `Conv2d` with default settings reduces each spatial dimension by 2, and a 2x2 max pool floor-halves each dimension.

Yes, that is correct, as long as your `Conv2d` has a stride of one and zero padding (the defaults).

To resolve your problem, add `x = torch.squeeze(x)` just before the call to `self.linear1(x)`. Before the squeeze, `x` has a shape of `B x 32 x 1 x 1`; squeezing it gives `B x 32`, which is compatible with your `Linear` layer (`B` being the batch size).
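For reference, a small sketch of what the suggested squeeze does to the activation shape (the batch size of 4 is arbitrary; the caveat at the end is my own observation, not from the thread):

```python
import torch

x = torch.randn(4, 32, 1, 1)  # B x 32 x 1 x 1, with B = 4
y = torch.squeeze(x)          # drops all size-1 dims -> 4 x 32
print(y.shape)                # torch.Size([4, 32])

# Caveat: with a batch size of 1, squeeze() with no arguments would also
# drop the batch dimension, so squeezing only the trailing dims is safer:
z = x.squeeze(3).squeeze(2)   # always B x 32, regardless of B
print(z.shape)                # torch.Size([4, 32])
```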

Also (this does not change anything here), you call `self.pool3(x)` twice during your forward pass.


I tried it, but for some reason it was not working as you explained. I’m not sure whether my calculations in the original post are correct.
After some searching I came across the `torch.flatten()` function; I had to use it as `torch.flatten(x, 1, 3)`. Please let me know your thoughts on this.

Also, you said that we use `self.pool3(x)` twice. Is there any specific reason for this?

Ok, it worked when I tested it with `torch.squeeze`, but yes, in this case `torch.flatten(x, 1, 3)` does the same thing: it collapses dimensions 1 through 3 (of sizes 32, 1 and 1) into a single dimension of size 32, eliminating the size-1 dimensions.
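A quick sketch (mine, not from the thread) showing that the two approaches agree for this shape, and where they differ:

```python
import torch

x = torch.randn(4, 32, 1, 1)
a = torch.squeeze(x)            # 4 x 32
b = torch.flatten(x, 1, 3)      # merges dims 1..3 into one -> 4 x 32
print(torch.equal(a, b))        # True: same values, same shape

# flatten also handles the general case where the spatial dims are
# larger than 1 (squeeze would leave those dims untouched):
y = torch.randn(4, 32, 5, 5)
print(torch.flatten(y, 1, 3).shape)   # torch.Size([4, 800])
```

This is why `flatten` is the more common idiom before a `Linear` layer: it works whether or not the spatial dimensions have already been reduced to 1x1.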

```python
x = self.conv2(x)
x = self.pool3(x)
x = F.relu(x)
x = self.conv3(x)
x = self.pool3(x)
```

It does not change the result, since `self.pool2` and `self.pool3` are identical `MaxPool2d(2)` layers with no learnable parameters, but `self.pool3` is used after `self.conv2`, which means that `self.pool2` is never used, that’s all!