Transition from Conv2d to Linear Layer Equations

Hi everyone,
First post here. I'm having trouble finding the right resources to understand how to calculate the dimensions required to transition from a conv block to a linear block. I have seen several equations, which I attempted to implement unsuccessfully:


  1. “The formula for output neuron:
    Output = ((I-K+2P)/S + 1), where
    I - a size of input neuron,
    K - kernel size,
    P - padding,
    S - stride.”

and

  2. “W' = (W - F + 2P)/S + 1”

The example network that I have been trying to understand is a CNN for the CIFAR-10 dataset.

Below is the third conv layer block, which feeds into a linear layer with 4096 input features:

        # Conv Layer block 3
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=2, stride=2),
    )


    self.fc_layer = nn.Sequential(
        nn.Dropout(p=0.1),
        nn.Linear(4096, 1024),
        nn.ReLU(inplace=True),
        nn.Linear(1024, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.1),
        nn.Linear(512, 10)
    )

I need to figure out the equations, resources, or procedure for calculating this transition between the conv and linear layers. How did we arrive at 4096?

EDIT: I have also used ptrblck’s print-layer (below) for help, but still struggle to understand this transition intuitively.

import torch.nn as nn

class Print(nn.Module):
    def forward(self, x):
        print(x.size())
        return x
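A minimal sketch of how such a Print layer can be dropped into an nn.Sequential to watch the shape change at each step of conv block 3. The [batch_size, 128, 8, 8] input is an assumption, since the first two conv blocks of the network are not shown; nn.Flatten stands in for whatever flattening the full model does before fc_layer.

```python
import torch
import torch.nn as nn


class Print(nn.Module):
    """Identity layer that prints the shape of whatever passes through it."""

    def forward(self, x):
        print(x.size())
        return x


# Conv block 3 from the question, with Print layers interleaved.
# Assumption: the block receives a [batch_size, 128, 8, 8] activation.
block3 = nn.Sequential(
    Print(),
    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),
    Print(),
    nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=2, stride=2),
    Print(),
    nn.Flatten(),  # [N, C, H, W] -> [N, C*H*W]
    Print(),
)

out = block3(torch.zeros(2, 128, 8, 8))
# Prints, in order:
# torch.Size([2, 128, 8, 8])
# torch.Size([2, 256, 8, 8])
# torch.Size([2, 256, 4, 4])
# torch.Size([2, 4096])
```

The last printed size is exactly the in_features the first nn.Linear has to accept.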

Any and all help greatly appreciated, Dan

cited:

  1. Linear layer input neurons number calculation after conv2d
  2. https://datascience.stackexchange.com/questions/40906/determining-size-of-fc-layer-after-conv-layer-in-pytorch

Your output formula is missing the dilation and also the subtraction from the kernel size.
The Conv2d docs show you the formula which is used.
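For reference, the formula in the nn.Conv2d docs (per spatial dimension) is H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1). Below is a small plain-Python helper to check it by hand; conv2d_out_size is a hypothetical name for illustration, not a PyTorch function. nn.MaxPool2d (with the default ceil_mode=False) uses the same formula, which is why it can be reused for the pooling step.

```python
def conv2d_out_size(size_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output height (or width) of nn.Conv2d along one spatial dimension.

    Mirrors the formula in the nn.Conv2d docs; Python's // performs the floor().
    """
    return (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# kernel 3, padding 1, stride 1 keeps the size unchanged
print(conv2d_out_size(32, kernel_size=3, padding=1))            # 32
# the same conv with stride 2 halves it (the floor matters here: 31 // 2 = 15)
print(conv2d_out_size(32, kernel_size=3, stride=2, padding=1))  # 16
# MaxPool2d(kernel_size=2, stride=2): same formula with padding=0
print(conv2d_out_size(8, kernel_size=2, stride=2))              # 4
```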

That being said, your posted conv layer block would keep the spatial dimensions equal in the first layers, since conv layers with a kernel size of 3, padding of 1, and the default stride of 1 would not reduce the height or width of the activation. The max pooling layer (kernel_size=2, stride=2) would then halve the spatial dimensions.

Based on in_features=4096 and out_channels=256 (4096 / 256 = 16 = 4*4), I thus assume that the input to the conv block would be [batch_size, 128, 8, 8] and the output [batch_size, 256, 4, 4], which would be flattened to [batch_size, 256*4*4 = 4096].
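The same reasoning as plain arithmetic, worked in both directions. The backward step assumes a square feature map; the forward step assumes (this is not shown in the question) that conv blocks 1 and 2 are built like block 3, i.e. size-preserving 3x3 convs followed by a 2x2 max pool.

```python
import math

in_features = 4096   # from nn.Linear(4096, 1024)
out_channels = 256   # from the last nn.Conv2d in block 3

# Backward: flattening turns [N, C, H, W] into [N, C*H*W], so H*W = in_features / C.
spatial = in_features // out_channels  # 16
side = math.isqrt(spatial)             # 4, assuming H == W
print(spatial, side)                   # 16 4

# Forward: start from a 32x32 CIFAR-10 image and apply three 2x2 max pools.
size = 32
for _ in range(3):
    size //= 2  # each MaxPool2d(kernel_size=2, stride=2) halves H and W
print(size, out_channels * size * size)  # 4 4096
```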

The Convolution arithmetic tutorial is a very good website to learn more about how convolution layers perform the window sliding.


Here is some demo code where I have put in comments and prints to try to explain this as simply as possible:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
            
        #image-tensor goes in as batch_sizex3x32x32
        #print-1 will show this state
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        
        #image-tensor is batch_sizex16x32x32 since: (32-3+2*1)/1+1=32
        #print-2 will show this state
        
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(32)
        
        #image-tensor is batch_sizex32x16x16 since: floor((32-3+2*1)/2)+1 = floor(15.5)+1 = 16
        #print-3 will show this state
        
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        #image-tensor is batch_sizex32x8x8 since max pooling (kernel 2, stride 2) halves H and W: floor((16-2)/2)+1 = 8
        #print-4 will show this state

        #now we flatten image tensor to batch_sizex32*8*8 which is batch_sizex2048
        #print-5 will show this state
        self.fc1 = nn.Linear(32 * 8 * 8, 10) #same as: self.fc1 = nn.Linear(2048, 10)
            
            
    def forward(self, x):
        print("print-1:")
        print(x.shape)
        x = F.relu(self.bn1(self.conv1(x)))
        print("print-2:")
        print(x.shape)
        x = F.relu(self.bn2(self.conv2(x)))
        print("print-3:")
        print(x.shape)
        x = self.max_pool(x)
        print("print-4:")
        print(x.shape)
        x = x.view(-1, 32 * 8 * 8)
        print("print-5:")
        print(x.shape)
        x = self.fc1(x)
        return x
    
model = CNN()
x = torch.ones(4, 3, 32, 32)
model(x)

Out:

print-1:
torch.Size([4, 3, 32, 32])
print-2:
torch.Size([4, 16, 32, 32])
print-3:
torch.Size([4, 32, 16, 16])
print-4:
torch.Size([4, 32, 8, 8])
print-5:
torch.Size([4, 2048])
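Rather than hard-coding 32 * 8 * 8, one common pattern is to push a dummy tensor through the conv part once in __init__ and read off the flattened size. A minimal sketch, assuming 3x32x32 inputs; the class name CNNAutoFlatten and the in_shape/num_classes parameters are illustrative, not from the demo above.

```python
import torch
import torch.nn as nn


class CNNAutoFlatten(nn.Module):
    """Same layers as the demo CNN, but fc1's in_features is measured, not hard-coded."""

    def __init__(self, in_shape=(3, 32, 32), num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Run one dummy sample through the conv part to measure its output.
        # no_grad avoids building a graph; eval() stops BatchNorm from
        # updating its running statistics with the dummy data.
        self.features.eval()
        with torch.no_grad():
            n_flat = self.features(torch.zeros(1, *in_shape)).flatten(1).shape[1]
        self.features.train()
        self.fc1 = nn.Linear(n_flat, num_classes)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)  # keep the batch dim, flatten C, H, W
        return self.fc1(x)


model = CNNAutoFlatten()
print(model.fc1.in_features)                  # 2048
print(model(torch.ones(4, 3, 32, 32)).shape)  # torch.Size([4, 10])
```

The trade-off is that the model is tied to the in_shape passed at construction; a different image size means building a new model.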

If it is still unclear, you should really check out some tutorials on this topic, like the one ptrblck suggested:

Or check out this very popular lecture from Stanford University.

Thank you both for the help, I really appreciate it.


Below is a cheeky little utility that helps you avoid doing all the tedious math by piggybacking on PyTorch’s error messages. The assumption is that net is your Module, which is composed of a sequence of input-shape-agnostic layers (such as Conv, ReLU, BatchNorm, and MaxPool layers), followed at some point by a flattening step and a linear layer whose required size you’re trying to figure out. height and width refer to your desired input image shape. Hope this is useful!

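import torch

# Assumes `net`, `height`, and `width` are already defined (see above).
try:
    with torch.no_grad():
        net(torch.rand((1, 3, height, width)))
    print("Image size is compatible with layer sizes.")
except RuntimeError as e:
    msg = str(e)
    if msg.endswith("Output size is too small"):
        print("Image size is too small.")
    elif "shapes cannot be multiplied" in msg:
        # e.g. "mat1 and mat2 shapes cannot be multiplied (1x2048 and 4096x1024)":
        # the number after the first "x" is the flattened feature count.
        required_shape = msg[msg.index("x") + 1:].split(" ")[0]
        print(f"Linear layer needs to have size: {required_shape}")
    else:
        print(f"Error not understood: {msg}")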
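If your PyTorch version provides nn.LazyLinear, another option that avoids parsing error strings is to let the linear layer infer in_features on its first forward pass. A sketch with a small hypothetical network and an assumed 32x32 input; after the first call the lazy layer turns into a regular nn.Linear.

```python
import torch
import torch.nn as nn

# nn.LazyLinear only takes out_features; in_features is inferred from the
# first input it sees.
net = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),  # 32x32 -> 32x32
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),       # 32x32 -> 16x16
    nn.Flatten(),                                # 16 * 16 * 16 = 4096
    nn.LazyLinear(10),
)

height, width = 32, 32  # assumed input image size
# Dry run to materialize the lazy layer; do this before creating the
# optimizer so it sees the real weight tensor.
net(torch.zeros(1, 3, height, width))

print(net[-1].in_features)   # 4096
print(net[-1].weight.shape)  # torch.Size([10, 4096])
```

You can still print net[-1].in_features afterwards and hard-code it, if you prefer an explicit nn.Linear in the final model.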