How is the first parameter of nn.Linear selected?

I am trying to understand how to build a neural network in PyTorch, and I am experimenting with the MNIST dataset. I don't understand why the first parameter of nn.Linear is set to 9216.

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Conv2d arguments: in_channels, out_channels, kernel_size, stride
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        # Linear arguments: in_features, out_features
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

my_nn = Net()

The first parameter, in_features, is the number of input features the linear layer expects to receive from the previous layer. You can refer to the PyTorch docs for a fuller explanation.
I would suggest working through the PyTorch tutorial "Deep Learning with PyTorch: A 60 Minute Blitz".
For your current question, see the "Build the Neural Network" tutorial in the PyTorch Tutorials documentation.

Hi Aleemsidra!

The first parameter – in_features – used to construct a Linear
must match the number of features in the tensor passed to it. That
is, that tensor should have shape [nBatch, in_features] (where
the Linear doesn’t care about the batch size, nBatch, which can
be anything and vary from call to call).
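
To make the shape requirement concrete, here is a minimal sketch using the
same Linear (9216, 128) as in your snippet. The variable names and the batch
sizes (4 and 100) are arbitrary choices for illustration only.

```python
import torch
import torch.nn as nn

fc1 = nn.Linear(9216, 128)  # in_features = 9216, out_features = 128

# Any batch size works as long as the last dimension is 9216.
out_small = fc1(torch.randn(4, 9216))
out_large = fc1(torch.randn(100, 9216))
print('out_small.shape =', out_small.shape)
print('out_large.shape =', out_large.shape)

# A tensor with the wrong number of features fails when the layer is called.
try:
    fc1(torch.randn(4, 1000))
    mismatch_failed = False
except RuntimeError as err:
    mismatch_failed = True
    print('RuntimeError:', err)
```

Both valid calls return a tensor of shape [nBatch, 128]; only the feature
dimension is checked against in_features.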

In a CNN you will typically have some convolutions that are applied
to images of arbitrary size (but that have to be at least as large as
the kernel size and the number of channels must match). These
convolutions preserve the spatial structure of the image.

But when you get to the Linear layers you typically flatten the
processed image – throwing away the spatial structure – so it
becomes a vector of “features.” This is where 9216 comes from.

In your case, if you start with images of height = 16 and width = 16,
the Conv2d layers will eat away a little at the edges of the images
(16 -> 14 -> 12 in each dimension), but increase the number of channels
(to 64). Each processed sample in the batch will now have
64 * 12 * 12 = 9216 elements and will turn into a vector of 9216
features when flattened.

The MNIST dataset has images of size 28x28, so the numbers
don’t quite work without some further processing. The dataloader
might hypothetically resize the images (for whatever reason). The
network might well have a pooling layer (common in CNNs) that
isn’t shown in your code snippet. If you input images of size 28x28
but also have a 2x2 pooling layer after the convolutions, you will
also end up with 9216 features.

Here’s an illustrative example that uses the numbers from your above
code snippet:

import torch
print (torch.__version__)

nBatch = 32
height = 16
width = 16

conv1 = torch.nn.Conv2d (1, 32, 3, 1)
conv2 = torch.nn.Conv2d (32, 64, 3, 1)

s = torch.randn (nBatch, 1, height, width)
print ('s.shape =', s.shape)
print ('conv1 (s).shape =', conv1 (s).shape)
print ('conv2 (conv1 (s)).shape =', conv2 (conv1 (s)).shape)
print ('torch.flatten (conv2 (conv1 (s)), 1).shape =', torch.flatten (conv2 (conv1 (s)), 1).shape)

height = 28
width = 28

pool = torch.nn.MaxPool2d (2)

t = torch.randn (nBatch, 1, height, width)
print ('t.shape =', t.shape)
print ('conv1 (t).shape =', conv1 (t).shape)
print ('conv2 (conv1 (t)).shape =', conv2 (conv1 (t)).shape)
print ('pool (conv2 (conv1 (t))).shape =', pool (conv2 (conv1 (t))).shape)
print ('torch.flatten (pool (conv2 (conv1 (t))), 1).shape =', torch.flatten (pool (conv2 (conv1 (t))), 1).shape)

Here is its output:

s.shape = torch.Size([32, 1, 16, 16])
conv1 (s).shape = torch.Size([32, 32, 14, 14])
conv2 (conv1 (s)).shape = torch.Size([32, 64, 12, 12])
torch.flatten (conv2 (conv1 (s)), 1).shape = torch.Size([32, 9216])
t.shape = torch.Size([32, 1, 28, 28])
conv1 (t).shape = torch.Size([32, 32, 26, 26])
conv2 (conv1 (t)).shape = torch.Size([32, 64, 24, 24])
pool (conv2 (conv1 (t))).shape = torch.Size([32, 64, 12, 12])
torch.flatten (pool (conv2 (conv1 (t))), 1).shape = torch.Size([32, 9216])

You can see how each Conv2d layer (3x3 kernel, no padding) trims one
pixel off each edge of the image, so the height and width each shrink
by 2, and how flatten() produces a batch of vectors with 9216 features.
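
The shapes in the output above can also be worked out by hand. The sketch
below uses the output-size formula given in the Conv2d and MaxPool2d
documentation, out = floor((in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1.
The helper names conv_out_size and flattened_features are made up for this
illustration, and the 64 channels and two 3x3 convolutions are taken from
your snippet.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension of a Conv2d / MaxPool2d."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_features(height, width, pool=None):
    """Features after conv1 and conv2 (3x3, stride 1), optional pooling, flatten."""
    h, w = height, width
    for _ in range(2):  # conv1, then conv2
        h, w = conv_out_size(h, 3), conv_out_size(w, 3)
    if pool is not None:
        # MaxPool2d (pool) uses stride == kernel_size by default.
        h = conv_out_size(h, pool, stride=pool)
        w = conv_out_size(w, pool, stride=pool)
    return 64 * h * w  # conv2 has 64 output channels


print(flattened_features(16, 16))           # 64 * 12 * 12 = 9216
print(flattened_features(28, 28, pool=2))   # 64 * 12 * 12 = 9216
print(flattened_features(28, 28))           # 64 * 24 * 24 = 36864
```

The last line shows that 28x28 inputs without pooling would need
in_features = 36864 rather than 9216.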

Note that such a network architecture (without further logic in it) will
only work on images of the appropriate fixed size (but the batch size
can be arbitrary).
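
If you would rather not hard-code 9216, one common approach is to push a
dummy input through the convolutional part once and read off the flattened
size. This is only a sketch and not necessarily how your tutorial's network
is written: the MaxPool2d layer, the relu activations, the input_size
argument, and the _features helper are my assumptions, and I have left out
the dropout layers for brevity.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self, input_size=(1, 28, 28)):  # (channels, height, width), assumed
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.pool = nn.MaxPool2d(2)  # assumed; not shown in the original snippet
        # Run one dummy image through the conv stack to count the features.
        with torch.no_grad():
            n_features = self._features(torch.zeros(1, *input_size)).shape[1]
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 10)

    def _features(self, x):
        # Convolutions keep the spatial structure; flatten throws it away.
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.pool(x)
        return torch.flatten(x, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(self._features(x)))
        return self.fc2(x)


net = Net()
print('fc1.in_features =', net.fc1.in_features)
print('output shape =', net(torch.randn(5, 1, 28, 28)).shape)
```

The network is still tied to one image size after construction, but the
size is now stated once in input_size instead of being baked into the
Linear by hand.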


K. Frank