I am trying to understand how to build a NN in pytorch. I am playing with the MNIST dataset. I am not able to understand how the first parameter in nn.Linear is set to 9216.
The first parameter – in_features – used to construct a Linear
must match the number of features in the tensor passed to it. That
is, that tensor should have shape [nBatch, in_features] (where
the Linear doesn’t care about the batch size, nBatch, which can
be anything and vary from call to call).
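To make the shape requirement concrete, here is a minimal sketch (the 128 out_features and the batch size of 5 are arbitrary choices for illustration):

```python
import torch

lin = torch.nn.Linear(9216, 128)  # in_features = 9216; out_features = 128 is arbitrary here
x = torch.randn(5, 9216)          # batch of 5 samples, each with 9216 features
y = lin(x)
print(y.shape)                    # torch.Size([5, 128])
```

Any batch size works, but the last dimension of the input must be exactly 9216, or Linear will raise a shape-mismatch error.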
In a CNN you will typically have some convolutions that are applied
to images of arbitrary size (but that have to be at least as large as
the kernel size and the number of channels must match). These
convolutions preserve the spatial structure of the image.
But when you get to the Linear layers you typically flatten the
processed image – throwing away the spatial structure – so it
becomes a vector of “features.” This is where 9216 comes from.
In your case, if you start with images of height = 16 and width = 16,
the Conv2d layers will eat away a little at the edges of the images, but
increase the number of channels (to 64). Each processed sample in the
batch will then have 64 × 12 × 12 = 9216 elements and will turn into a
vector of 9216 features when flattened.
The MNIST dataset has images of size 28x28, so the numbers
don’t quite work without some further processing. The dataloader
might hypothetically resize the images (for whatever reason). The
network might well have a pooling layer (common in CNNs) that
isn’t shown in your code snippet. If you input images of size 28x28
but also have a 2x2 pooling layer after the convolutions, the spatial
size goes 28 → 26 → 24 → 12, so you again end up with
64 × 12 × 12 = 9216 features.
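Here is a sketch of that 28x28 path. The specific layers (Conv2d(1, 32, 3) and Conv2d(32, 64, 3), i.e. kernel size 3) are assumptions based on the channel counts discussed above, and the max_pool2d placement is hypothetical:

```python
import torch
import torch.nn.functional as F

conv1 = torch.nn.Conv2d(1, 32, 3)   # assumed layers: 3x3 kernels, no padding
conv2 = torch.nn.Conv2d(32, 64, 3)

x = torch.randn(1, 1, 28, 28)       # a batch of one single-channel 28x28 image
x = conv1(x)                        # -> [1, 32, 26, 26]
x = conv2(x)                        # -> [1, 64, 24, 24]
x = F.max_pool2d(x, 2)              # -> [1, 64, 12, 12]
x = torch.flatten(x, 1)             # -> [1, 9216]
print(x.shape)
```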
Here’s an illustrative example that uses the numbers from your above
code snippet:
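(A sketch, assuming Conv2d(1, 32, 3) and Conv2d(32, 64, 3) as the two convolutions, with 16x16 inputs as discussed above:)

```python
import torch

conv1 = torch.nn.Conv2d(1, 32, 3)   # assumed: 3x3 kernels, no padding
conv2 = torch.nn.Conv2d(32, 64, 3)

x = torch.randn(2, 1, 16, 16)       # batch of two single-channel 16x16 images
print(x.shape)                      # torch.Size([2, 1, 16, 16])
x = conv1(x)
print(x.shape)                      # torch.Size([2, 32, 14, 14])
x = conv2(x)
print(x.shape)                      # torch.Size([2, 64, 12, 12])
x = torch.flatten(x, 1)             # keep the batch dimension, flatten the rest
print(x.shape)                      # torch.Size([2, 9216])
```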
You can see how each Conv2d layer trims a row (and column) of pixels off
each edge of the image and how flatten() produces a batch of vectors with 9216 features.
Note that such a network architecture (without further logic in it) will
only work on images of the appropriate fixed size (but the batch size
can be arbitrary).