Understanding how to compute shape of input to Linear() layer

Hi,

I’m trying to work through, by hand, the size of the input at each layer of my CNN.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, 3)
    self.pool1 = nn.MaxPool2d(2)
    self.conv2 = nn.Conv2d(10, 20, 3)
    self.pool2 = nn.MaxPool2d(2)
    self.conv3 = nn.Conv2d(20, 32, 3)
    self.pool3 = nn.MaxPool2d(2)
    self.linear1 = nn.Linear(32, 10)
  
  def forward(self, x):
    x = self.conv1(x)
    x = self.pool1(x)
    x = F.relu(x)
    x = self.conv2(x)
    x = self.pool3(x)
    x = F.relu(x)
    x = self.conv3(x)
    x = self.pool3(x)
    x = F.relu(x)
    print(x.shape)
    x = self.linear1(x)
    return F.softmax(x)

I’m using the MNIST dataset as input
Input size: 28x28
Size After conv1 layer: 26x26
Size After Pool1 layer: 13x13
Size After conv2 layer: 11x11
Size After Pool2 layer: 5x5
Size After conv3 layer: 3x3
Size After Pool3 layer: 1x1
So the number of features at the input of the linear layer is 32x1x1 = 32
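The walkthrough above can be checked with a shape-only sketch in plain Python (no torch needed), using the standard output-size formula and assuming stride-1, padding-0 convolutions with a 3x3 kernel and MaxPool2d(2):

```python
# Conv2d output size: floor((n + 2*p - k) / s) + 1; with k=3, s=1, p=0 this is n - 2.
# MaxPool2d(2) output size: floor(n / 2).

def conv_out(n, k=3, s=1, p=0):
    return (n + 2 * p - k) // s + 1

def pool_out(n, k=2):
    return n // k

n = 28  # MNIST spatial size
for layer in ["conv1", "pool1", "conv2", "pool2", "conv3", "pool3"]:
    n = conv_out(n) if layer.startswith("conv") else pool_out(n)
    print(layer, n)
# conv1 26, pool1 13, conv2 11, pool2 5, conv3 3, pool3 1
```

This reproduces the sizes listed above, ending at 1x1 with 32 channels.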

But I get an error

import torch
from torch.autograd import Variable
maxPoolLayer = nn.MaxPool2d(2)
conv2dLayer = nn.Conv2d(1, 10, 3)
mA = Variable(torch.randn(1, 1, 25, 25))
mA.size()
mC = conv2dLayer(mA)
print(mC.size())
mB = maxPoolLayer(mC)
print(mB.size())

Output:
torch.Size([1, 10, 23, 23])
torch.Size([1, 10, 11, 11])

So a Conv2d layer with a 3x3 kernel and the default stride (1) and padding (0) will reduce both spatial dimensions of the input by 2, and MaxPool2d(2) will floor-halve them.

Yes, that is correct, if your Conv2d has a stride of one and a padding of zero.

What you need to do to resolve your problem is to add x = torch.squeeze(x) just before the call to self.linear1(x). Before the squeeze, x has a shape of B x 32 x 1 x 1; squeezing it makes the shape B x 32, which is compatible with your Linear layer (B being the batch size).
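To see what the suggested squeeze does to the shape, here is a shape-only sketch in plain Python (torch.squeeze itself removes every dimension of size 1; the batch size below is just an illustrative value):

```python
# Simulate torch.squeeze on a shape tuple: drop all size-1 dimensions.
def squeeze_shape(shape):
    return tuple(d for d in shape if d != 1)

B = 64  # hypothetical batch size
print(squeeze_shape((B, 32, 1, 1)))  # -> (64, 32), which nn.Linear(32, 10) accepts
```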

Also (this does not change anything), you use self.pool3(x) twice during your forward pass :slight_smile:


Thank you for your help!
I tried it, but for some reason it did not work as you explained. I’m not sure whether my calculations in the original post are correct.
I did some searching and finally came across the torch.flatten() function; I had to use it as torch.flatten(x, 1, 3). Please let me know your thoughts.

Also, you said that self.pool3(x) is used twice. Is there a specific reason for this?

Ok, it worked when I tested it with torch.squeeze, but yes, in this case torch.flatten(x, 1, 3) does the same thing: it collapses dimensions 1 through 3 into one, which here just eliminates the dimensions of size 1.
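For what it’s worth, the two calls differ when the batch size is 1: torch.squeeze removes every size-1 dimension (including the batch dimension), while torch.flatten(x, 1, 3) always keeps dimension 0 and merges dimensions 1 through 3. A shape-only sketch in plain Python:

```python
from math import prod

def squeeze_shape(shape):
    # torch.squeeze: drop every dimension of size 1
    return tuple(d for d in shape if d != 1)

def flatten_shape(shape, start=1, end=3):
    # torch.flatten(x, start, end): merge dims start..end into one
    return shape[:start] + (prod(shape[start:end + 1]),) + shape[end + 1:]

print(squeeze_shape((1, 32, 1, 1)))   # -> (32,)   batch dimension is gone!
print(flatten_shape((1, 32, 1, 1)))   # -> (1, 32) batch dimension preserved
```

This is why flatten (or a view/reshape that keeps dimension 0) is the safer choice before a Linear layer.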

    x = self.conv2(x)
    x = self.pool3(x)
    x = F.relu(x)
    x = self.conv3(x)
    x = self.pool3(x)

It does not change anything, but self.pool3 is used after self.conv2, which means that self.pool2 is never used, that’s all!