Understanding how to compute the shape of the input to a Linear() layer


I’m trying to work out by hand the size of the input at each layer of my CNN.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, 3)
    self.pool1 = nn.MaxPool2d(2)
    self.conv2 = nn.Conv2d(10, 20, 3)
    self.pool2 = nn.MaxPool2d(2)
    self.conv3 = nn.Conv2d(20, 32, 3)
    self.pool3 = nn.MaxPool2d(2)
    self.linear1 = nn.Linear(32, 10)
  def forward(self, x):
    x = self.conv1(x)
    x = self.pool1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = self.pool3(x)
    x = F.relu(x)
    x = self.conv3(x)
    x = self.pool3(x)
    x = F.relu(x)
    x = self.linear1(x)
    return F.softmax(x)

I’m using the MNIST dataset as input.
Input size: 28x28
Size after conv1: 26x26
Size after pool1: 13x13
Size after conv2: 11x11
Size after pool2: 5x5
Size after conv3: 3x3
Size after pool3: 1x1
So the number of input features to the linear layer is 32x1x1 = 32.
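The walkthrough above can be checked in plain Python with the standard size formulas (floor((n + 2*pad - kernel) / stride) + 1 for a convolution, floor(n / 2) for MaxPool2d(2)); the helper names here are just for illustration:

```python
# Recompute the spatial sizes from the walkthrough above.
def conv_out(n, kernel=3, stride=1, pad=0):
    # Conv2d output size: floor((n + 2*pad - kernel) / stride) + 1
    return (n + 2 * pad - kernel) // stride + 1

def pool_out(n, kernel=2):
    # MaxPool2d(2): stride defaults to the kernel size, so this floor-halves
    return n // kernel

n = 28            # MNIST input
n = pool_out(conv_out(n))   # conv1 -> 26, pool1 -> 13
n = pool_out(conv_out(n))   # conv2 -> 11, pool2 -> 5
n = pool_out(conv_out(n))   # conv3 -> 3,  pool3 -> 1
print(n)          # -> 1, so the tensor entering the linear layer is B x 32 x 1 x 1
```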

But I get an error when I run this.

from torch.autograd import Variable  # Variable is a legacy wrapper; plain tensors behave the same in recent PyTorch
maxPoolLayer = nn.MaxPool2d(2)
conv2dLayer = nn.Conv2d(1, 10, 3)
mA = Variable(torch.randn(1, 1, 25, 25))
mC = conv2dLayer(mA)
mB = maxPoolLayer(mC)
print(mC.shape)
print(mB.shape)

torch.Size([1, 10, 23, 23])
torch.Size([1, 10, 11, 11])

So a Conv2d with kernel size 3 will reduce the input by 2 in both spatial dimensions, and MaxPool2d(2) will floor-halve the input.

Yes, that is correct, if your Conv2d has a stride of one and 0 padding.

What you need to do to resolve your problem is add x = torch.squeeze(x) just before the call to self.linear1(x). Before the squeeze, x has a shape of B x 32 x 1 x 1; squeezing it makes the shape B x 32, which is compatible with your Linear layer (B being the batch size).
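A minimal sketch of the suggested fix, using a batch of 4 random tensors in place of real data:

```python
import torch
import torch.nn as nn

linear1 = nn.Linear(32, 10)

x = torch.randn(4, 32, 1, 1)   # shape after pool3: B x 32 x 1 x 1, here B = 4
x = torch.squeeze(x)           # drops ALL size-1 dims -> 4 x 32
print(x.shape)                 # torch.Size([4, 32])

out = linear1(x)               # now compatible with Linear(32, 10)
print(out.shape)               # torch.Size([4, 10])
```

One caveat: because squeeze removes every size-1 dimension, a batch of size 1 would lose its batch dimension too.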

Also (this does not change anything), you use self.pool3(x) twice during your forward pass :slight_smile:


Thank you for your help!
I tried it, but for some reason it was not working as you explained. I’m not sure if my calculations in the original post are correct.
After some searching I finally came across the torch.flatten() function, which I had to use as torch.flatten(x, 1, 3). Please let me know your thoughts on this.

Also, you said that self.pool3(x) is used twice. Is there any specific reason for this?

Ok, it worked when I tested with torch.squeeze, but yes, in this case torch.flatten(x, 1, 3) does the same thing, i.e. it flattens dimensions 1 through 3 into a single dimension, eliminating the trailing dimensions of size 1.

    x = self.conv2(x)
    x = self.pool3(x)
    x = F.relu(x)
    x = self.conv3(x)
    x = self.pool3(x)

It does not change anything here (pool2 and pool3 are both MaxPool2d(2), so the output is identical), but self.pool3 is used after self.conv2, which means that self.pool2 is never used, that’s all!
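Putting the fixes from this thread together, one possible corrected version of the model (a sketch, not necessarily the original poster's final code) would use pool2 where intended and flatten before the linear layer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(10, 20, 3)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(20, 32, 3)
        self.pool3 = nn.MaxPool2d(2)
        self.linear1 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.pool1(self.conv1(x)))
        x = F.relu(self.pool2(self.conv2(x)))   # pool2, not pool3
        x = F.relu(self.pool3(self.conv3(x)))
        x = torch.flatten(x, 1)                 # B x 32 x 1 x 1 -> B x 32
        x = self.linear1(x)
        return F.softmax(x, dim=1)              # dim=1 avoids the implicit-dim warning

model = Model()
out = model(torch.randn(2, 1, 28, 28))          # batch of 2 MNIST-sized inputs
print(out.shape)                                # torch.Size([2, 10])
```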