How to delete layer in pretrained model?

The activation has the shape [batch_size, out_channels, height, width], where out_channels is the number of filters in the last conv layer.

Should we also remove the weights of those removed layers? If yes, how should it be done? Thank you.

Layers are implemented as nn.Modules, which hold parameters and buffers, if applicable, and define a forward method.
If you remove a layer which contained weights, the weights will also be removed and not used anymore.

Hi @ptrblck, I saw a seemingly quick method to delete some layers: simply use del resnet_model.fc, as done in this project. Are there any side effects of using such a method? Thanks in advance for your comments.

Yes, I would not recommend deleting the modules directly, as this would break the model, as seen here:

import torch
import torchvision.models as models

model = models.resnet18()
del model.fc

out = model(torch.randn(1, 3, 224, 224))
> ModuleAttributeError: 'ResNet' object has no attribute 'fc'

While the .fc module was removed, it’s still used in the forward method, which will raise this error.
In the linked example, the submodules of the resnet are called manually in a sequential manner, so you would have to be careful, as you won’t be able to call the model directly anymore.

Hi. I am struggling to remove the softmax of pretrained resnet18. I am trying to send the CNN output to a transformer encoder-decoder network. The encoder-decoder network expects an input tensor of size [2000]. A similar thread is here.

import torch
import torch.nn as nn
import torchvision.models as models

class ResnetEncoder(nn.Module):
  def __init__(self):
    super(ResnetEncoder, self).__init__()
    resnet = models.resnet18(pretrained=True)
    modules = list(resnet.children())[::-1]
    self.resnet = nn.Sequential(*modules)
  
  def forward(self, images):
    out = self.resnet(images)  # dimension: (batchsize * n frame, 3, 227, 227)
    out = out.view(-1, 2000)
    return out


model = ResnetEncoder()
input_frames = torch.randn(100, 3, 224, 224)
output = model(input_frames)

RuntimeError: size mismatch, m1: [168000 x 224], m2: [512 x 1000] at /pytorch/aten/src/TH/generic/THTensorMath.cpp:41

There is a size mismatch because the first layer of the resulting Sequential is now the Linear layer, which expects an input of size Z x 512.

ResnetEncoder(
  (resnet): Sequential(
    (0): Linear(in_features=512, out_features=1000, bias=True)
    (1): AdaptiveAvgPool2d(output_size=(1, 1))

Update:
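It was a silly mistake: modules = list(resnet.children())[::-1] reverses the list of child modules instead of dropping the last one.
The right syntax is modules = list(resnet.children())[:-1], which keeps every child except the final fc layer.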

I was wondering why the outputs of these two code snippets are different.
1)

resnet50 = models.resnet50(pretrained=True)
resnet50.fc = nn.Identity()
resnet50.avgpool = nn.Identity()
output = resnet50(torch.randn(1, 3, 224, 224))
print(output.shape)

output:
torch.Size([1, 100352])
and the model resnet50 looks like:

2)

resnet50_2 = models.resnet50(pretrained=True)
model = nn.Sequential(*list(resnet50_2.children())[:-2])
print(model(torch.randn(1, 3, 224, 224)).shape)

output:
torch.Size([1, 2048, 7, 7])
and the model looks like:

The first approach would still use the original forward method with the replaced avgpool and fc layers, so the flatten operation in that method would still be applied. The latter approach would call the modules sequentially and would thus drop the functional flatten call from the forward method.
As you can see, 2048*7*7=100352.

I am trying to cut inception v3 at mixed_5d as suggested in this paper.
Unfortunately I’m still confused, which way is recommended/possible: Is it possible to “stop” the forward function after mixed_5d and if yes, why wouldn’t I want that?
Thank you very much, I’m really desperate…

I don’t know whether the following works in earlier versions of PyTorch, but in v1.6 deleting a layer is as simple as:

del model.fc

This removes the layer from both model.modules and model.state_dict.
It also does not create zombie layers, as an Identity layer would, which simplifies model loading.

I think this is what @ptrblck mentioned here:
“While the .fc module was removed, it’s still used in the forward method, which will raise [an] error.”
But I’m curious as well: would something like
model.classifier = None or model.fc = None
work the same as
model.classifier = nn.Identity() or model.fc = nn.Identity()?

Setting the module to None would still remove it, but might break the forward method for the same reason:

import torch
import torchvision.models as models

model = models.resnet18()
x = torch.randn(1, 3, 224, 224)
out = model(x)

model.fc = None
out = model(x)
> TypeError: 'NoneType' object is not callable

Hello,
how can I remove some layers from the AlexNet model in PyTorch?

I am wondering if there is a cleaner way to do this. Sometimes the model is complex, and one cannot use a Sequential model.

Keras has a nice way to handle this by using layer.output and layer.input. One can simply connect the output of one layer to the input of another and compile the model.

It seems a bit weird that PyTorch does not provide a similar way to achieve this.

Probably the cleanest way to manipulate the usage of the layers and thus the forward execution would be to override the forward method and return the desired (intermediate) outputs.

Hi everyone,
I ended up with this network (modified from AlexNet). How can I get layers 0-11 (excluding the last maxpool)?

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
)

thank you

I think you can do model[0][:12].
I don’t know why, but as per your output, a Sequential is nested inside another Sequential, so you have to index twice: first [0] to get into the inner Sequential, then [:12] to get layers 0-11.

Is this approach correct? I tried it and it works, but I would appreciate your suggestions.

class ResNetModify(ImageClassificationBase):
    def __init__(self):
        super().__init__()
        self.network = models.resnet18(pretrained=True)
        # Remove the last two layers (avgpool and fc), then add a flatten layer
        self.network = nn.Sequential(*list(self.network.children())[:-2], nn.Flatten())
        # Number of flattened features coming out of the truncated network
        num_ftrs = 32768
        # Add a fully connected layer with one output per class
        self.network.fc = nn.Linear(num_ftrs, len(train_ds.classes))

    def forward(self, xb):
        return torch.sigmoid(self.network(xb))

Is there a way to keep the names from the original model? I want to use resnet18 as it is and just remove the fc layer. For resnet18, your solution becomes model = nn.Sequential(*list(full_model.children())[:-1]), but this removes the names. Do you think there is any way to use named_children to keep the whole structure of the original resnet, but without the last layer?

See screenshot attached for comparison. I did not capture all layers in the screenshot - the model is big.