Removing layers from ResNet pretrained model

Hi everyone! I am trying to apply to ResNet and GoogLeNet the same approach I used for VGG16 (see the code below). Basically, I want to remove some of the last layers of these models so that I can use them for feature extraction, but I get an error. I have read that the problem could be related to the flatten step, but I honestly do not know how to implement it. Could anyone please help me? Thanks in advance.

class resnet18_fe(nn.Module):
    def __init__(self):
        super(resnet18_fe, self).__init__()
        self.features = nn.Sequential(
            # removing the last convolutional layer
            *list(original_model_resnet.features.children())[:-3]
        )

    def forward(self, x):
        x = self.features(x)
        return x

model_3 = resnet18_fe().to(device=device)

The error I get is:

AttributeError                            Traceback (most recent call last)
in <module>()
     10     return x
     11 
---> 12 model_3 = resnet18_fe().to(device=device)
     13 
     14 print(model_3)

1 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in __getattr__(self, name)
    946                 return modules[name]
    947         raise AttributeError("'{}' object has no attribute '{}'".format(
--> 948             type(self).__name__, name))
    949 
    950     def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:

AttributeError: 'ResNet' object has no attribute 'features'

Check whether your ResNet model (original_model_resnet) has a features attribute.

Thanks for your reply. No, it does not have one. I guess that may be the cause of the error, but I am not sure how to implement it otherwise.

Can you upload your ResNet code?

My model is just the pretrained version imported from torchvision:

original_model_resnet = models.resnet18(pretrained=True).to(device=device)

Based on the torchvision ResNet implementation (vision/resnet.py at commit a75fdd4180683f7953d97ebbcc92d24682690f96 in pytorch/vision on GitHub), you could try inheritance, for example:

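import torch
from torchvision import models

class resnet18_fe(models.resnet.ResNet):
    def forward(self, x):
        # stem: convolution -> batch norm -> ReLU -> max pooling
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # the four residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # erase layers you want (e.g. delete the three lines below
        # to return the layer4 feature maps instead of class scores)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x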

Thanks, I will try it!