Remove the FC layer from the pretrained resnet50 and save the new model file

Hello, I am trying to build a CNN, and I want to use a ResNet-50 with its last fc layer removed as my network's backbone. I then want to save this fc-less ResNet-50 as a model file so that I can call `my_network.resnet.load_state_dict(torch.load("resnet50_reduce_fclayer.pth"))`.

I used `modules.children()` to do this.

But when I run `my_network.resnet.load_state_dict(torch.load("resnet_reduce_fc.pth"))`, I get an error:

Can anyone tell me how to deal with this problem? Thank you!

This code works fine for me:

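```python
import torch
import torch.nn as nn
import torchvision.models as model

res = model.resnet18(pretrained=True)
# drop the last two children (avgpool and fc)
res = list(res.children())[:-2]
_res = nn.Sequential(*res)
torch.save(_res.state_dict(), "./res.pth")
# torch.load returns the saved state dict, not a model
state_dict = torch.load("./res.pth")
```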

Yes, your code is OK, but my error occurs when I run `resnet.load_state_dict()`:
[Screenshot of the error message]
My ResNet definition:

```python
import torch.nn as nn
from torchvision.models.resnet import conv1x1  # helper used by _make_layer


class ResNet(nn.Module):
    def __init__(self, block, layers):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # no avgpool / fc: this module is used as a backbone

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        sources = []  # feature maps from layer1..layer4

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        sources += [x]

        x = self.layer2(x)
        sources += [x]

        x = self.layer3(x)
        sources += [x]

        x = self.layer4(x)
        sources += [x]

        return sources
```

Hello, I am wondering whether we can use this method when the original network's `forward()` does not simply run its layers in sequence, as in this CifarNet:

```python
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # pool is reused here...
        x = self.pool(F.relu(self.conv2(x)))  # ...and here
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

In such a case, will `nn.Sequential(*list(model.children())[:])` change the original computation graph?

Yes, wrapping the children in an `nn.Sequential` container changes the model and thus breaks it.
If you print it, you will get:

```
Sequential(
  (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (3): Linear(in_features=400, out_features=120, bias=True)
  (4): Linear(in_features=120, out_features=84, bias=True)
  (5): Linear(in_features=84, out_features=10, bias=True)
)
```

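This container applies the pooling layer only once (the same `pool` module is called twice in `forward()` but is registered as a single child), drops the `F.relu` calls (functional calls are not modules, so `children()` never sees them), and is missing the flattening step (`x = x.view(...)`). Running it therefore raises a shape mismatch error as soon as the 4D conv output reaches the first `nn.Linear`.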
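If the goal is to drop the last layer of such a model, here is a sketch of two options that keep the original computation, assuming the `Net` class from the question (repeated so the snippet runs on its own). Option 1 writes the `nn.Sequential` out by hand, reusing the trained modules and adding the ReLUs and an `nn.Flatten()` as modules. Option 2 keeps `forward()` as written and replaces `fc3` with `nn.Identity()`. The name `trunk` is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


net = Net()

# Option 1: an explicit Sequential that mirrors forward() up to fc2.
# It reuses the same module objects, so no weights are copied.
trunk = nn.Sequential(
    net.conv1, nn.ReLU(), net.pool,
    net.conv2, nn.ReLU(), net.pool,  # pool appears twice, as in forward()
    nn.Flatten(),  # plays the role of x.view(-1, 16 * 5 * 5)
    net.fc1, nn.ReLU(),
    net.fc2, nn.ReLU(),
)

# Option 2: keep forward() unchanged and turn fc3 into a no-op.
net.fc3 = nn.Identity()

x = torch.randn(2, 3, 32, 32)
with torch.no_grad():
    print(trunk(x).shape)  # torch.Size([2, 84])
    print(torch.allclose(trunk(x), net(x)))  # True
```

Option 2 is usually less error-prone because it cannot drift out of sync with `forward()`; option 1 is useful when a plain `nn.Sequential` is actually required.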