Deleting the last layer in a network results in a size mismatch!

I am working on a simple LeNet-style network with the following architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

I want to delete the last layer in the network, fc2. However, when I do so using the approach shown below, I get a size mismatch!

model = Net().to(device)
new_model = nn.Sequential(*list(model.children())[:-1]).to(device)
data, label = next(iter(train_loader))
data, label = data.to(device), label.to(device)
output = new_model(data)

This is throwing the following error:
RuntimeError: size mismatch, m1: [32 x 36864], m2: [9216 x 128] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:290

How can I delete the last layer without getting this error?

Hello Al!

When you create new_model it doesn’t know about the detailed
structure of your class Net. In particular, it doesn’t know about
Net’s forward() function. (It does know about model.children()).

So I don’t think any of the torch.nn.functional functions used
in model (an instance of your Net) get called when you execute
new_model(data). In particular, F.max_pool2d(x, 2) isn’t
getting called.
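You can see this directly with a toy module (a minimal sketch, not your Net): nn.Sequential(*children()) chains only the registered submodules, so any functional call that lives solely in forward() silently disappears:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# A tiny module whose forward() applies a functional pooling step
# that is NOT registered as a child module.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, 3)

    def forward(self, x):
        x = self.conv(x)
        return F.max_pool2d(x, 2)   # exists only inside forward()

tiny = Tiny()
seq = nn.Sequential(*tiny.children())  # contains only the conv layer

x = torch.randn(1, 1, 8, 8)
print(tiny(x).shape)  # torch.Size([1, 4, 3, 3]) -- conv then pool
print(seq(x).shape)   # torch.Size([1, 4, 6, 6]) -- conv only, no pool
```

The rebuilt Sequential runs conv alone, so its output is four times larger spatially than the original module's, which is exactly the kind of discrepancy behind your error.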

Note that the size of the flattened output of self.conv2 is
64 * 24 * 24 = 36864 (64 channels of 24 x 24 feature maps, given a
28 x 28 MNIST input), the second dimension of m1 in your size-mismatch
error. The size of the input expected by self.fc1 is 9216, the first
dimension of m2. Had max_pool2d() been called, it would have halved
each spatial dimension, reducing 36864 by a factor of 2**2 = 4 to
9216 and eliminating the size-mismatch error.
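The bookkeeping can be checked in a few lines (assuming a 28 x 28 MNIST input, so each 3 x 3, stride-1 conv shrinks the side by 2):

```python
# 28 -> conv1 -> 26 -> conv2 -> 24
channels, side = 64, 24
flat = channels * side * side          # flattened size without pooling
pooled = channels * (side // 2) ** 2   # flattened size after max_pool2d(x, 2)
print(flat, pooled)  # 36864 9216
```

36864 and 9216 match the two mismatched dimensions reported in the RuntimeError.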

Either build a New_Net class (of which new_model will be an
instance) with an appropriate forward function, or only use Modules
to build your Net class (and make Net an nn.Sequential). That is,
use Modules like nn.ReLU and nn.MaxPool2d instead of functions
like F.relu() and F.max_pool2d().
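Here is a sketch of the second option (note it swaps nn.Dropout in for nn.Dropout2d after the flatten, since the tensor is 2D at that point):

```python
import torch
import torch.nn as nn

# Module-only version of Net: every step, including activations and
# pooling, is a registered submodule, so rebuilding from children()
# preserves the full computation.
model = nn.Sequential(
    nn.Conv2d(1, 32, 3, 1),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, 1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Dropout2d(0.25),
    nn.Flatten(),
    nn.Linear(9216, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 10),
    nn.LogSoftmax(dim=1),
)

# Drop the head: [:-2] removes both fc2 and the log-softmax.
new_model = nn.Sequential(*list(model.children())[:-2])
x = torch.randn(32, 1, 28, 28)
print(new_model(x).shape)  # torch.Size([32, 128])
```

With an nn.Sequential you can also slice it directly (new_model = model[:-2]) instead of going through children().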

Good luck.

K. Frank


Great. I thought so. Thank you so much for the informative response!