How to extract features of an image from a trained model

What if I want to use the weights of ResNet34 as a pretrained model, but delete the last pooling layer and the fc layer?
Since the ResNet class doesn't use nn.Sequential, there seems to be no way to replace all of its layers with the layers I choose.
Help me! Thank you!!

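import torch.nn as nn
from torchvision.models import resnet34

original_model = resnet34(pretrained=True)

class myResNet34(nn.Module):
    def __init__(self):
        super(myResNet34, self).__init__()
        # children() yields conv1, bn1, relu, maxpool, layer1..layer4, avgpool, fc;
        # [:-2] drops only avgpool and fc ([:-3] would also drop layer4)
        self.features = nn.Sequential(*list(original_model.children())[:-2])

    def forward(self, x):
        x = self.features(x)
        return x

model = myResNet34()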

As above, does this make sense?

(model.layer[0:7])(input)
did the job for me.
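The slicing trick works whenever the attribute is an nn.Sequential, because indexing an nn.Sequential with a slice returns a new nn.Sequential that reuses the same layer objects. (`layer` is the attribute name from this post; torchvision's ResNet has no `layer` attribute, only `layer1` to `layer4`, so for ResNet you would wrap `children()` in an nn.Sequential first.) A small sketch with a made-up model:

```python
import torch
import torch.nn as nn

seq = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))

head = seq[0:2]               # a new nn.Sequential holding Linear(4, 8) and ReLU
print(type(head).__name__)    # Sequential
print(head[0] is seq[0])      # True: the layers are shared, not copied

y = head(torch.rand(3, 4))    # runs only the first two layers
print(y.shape)                # torch.Size([3, 8])
```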

I am trying to extract the features of a certain layer of a pretrained model. The following code runs; however, the values in template_feature_map changed even though I did nothing to it.

import torchvision.models as models

vgg_feature = models.vgg13(pretrained=True).features
template_feature_map = []

def save_template_feature_map(self, input, output):
    template_feature_map.append(output.detach())
    print(template_feature_map)

template_handle = vgg_feature[5].register_forward_hook(save_template_feature_map)
vgg_feature(template[0])
print(template_feature_map)

The output of the 6th layer (index 5) of the model should contain negative values, as the first print(template_feature_map) shows. But the negative values, which should be preserved in the second print(template_feature_map), have been changed to zeros, and I don't know why. If you know the mechanism behind this, please tell me how to keep the negative values.

The values printed by the two print(template_feature_map) calls:

[tensor([[[[-5.7389e-01, -2.7154e+00, -4.0990e+00,  ...,  4.1902e+00,
            3.1757e+00,  2.2461e+00],
          [-2.2217e+00, -4.3395e+00, -6.8158e+00,  ..., -1.4454e+00,
            9.8012e-01, -2.3653e+00],
          [-4.1940e+00, -6.3235e+00, -6.8422e+00,  ..., -2.8329e+00,
            2.5570e+00, -2.7704e+00],
          ...,
          [-3.3250e+00,  1.3792e-01,  5.4926e+00,  ..., -4.1722e+00,
           -6.1008e-01, -2.6037e+00],
          [ 1.5377e+00,  6.0671e-01,  2.0974e+00,  ...,  1.2441e+00,
            1.5033e+00, -2.7246e+00],
          [ 6.8857e-01, -3.5160e-02,  6.7858e-01,  ...,  1.2052e+00,
            1.4533e+00, -1.4160e+00]],

         [[ 6.8798e-01,  1.6971e+00,  2.1629e+00,  ...,  3.1701e-01,
            8.5424e-01,  2.8768e+00],
          [ 1.4013e+00,  2.7217e+00,  2.1476e+00,  ...,  3.1156e+00,
            4.4858e+00,  3.6936e+00],
          [ 3.1807e+00,  2.2245e+00,  2.4665e+00,  ...,  1.3838e+00,
            1.0580e-02, -3.1445e-03],
          ...,
          [-4.7298e+00, -3.3037e+00, -1.2982e+00,  ...,  2.3266e-01,
            6.7711e+00,  3.8166e+00],
          [-4.7972e+00, -5.4591e+00, -2.5201e+00,  ...,  3.7584e+00,
            5.1524e+00,  2.3072e+00],
          [-2.4306e+00, -2.8033e+00, -2.0912e+00,  ...,  1.9888e+00,
            2.0582e+00,  1.9266e+00]],

         [[-4.4257e+00, -4.6331e+00, -3.3580e-03,  ..., -8.2233e+00,
           -7.4645e+00, -1.7361e+00],
          [-4.5593e+00, -8.4195e+00, -8.8428e+00,  ..., -6.7950e+00,
           -1.4665e+01, -2.5335e+00],
          [-2.3481e+00, -3.8543e+00, -3.5965e+00,  ..., -1.5105e+00,
           -1.6923e+01, -5.9852e+00],
          ...,
          [-8.0165e+00,  8.0185e+00,  6.5506e+00,  ...,  5.3241e+00,
            3.3854e+00, -1.6342e+00],
          [-1.3689e+01, -2.2930e+00,  4.7097e+00,  ...,  3.2021e+00,
            2.9208e+00, -8.0228e-01],
          [-1.3055e+01, -1.1470e+01, -8.4442e+00,  ...,  1.8155e-02,
           -6.2866e-02, -2.0333e+00]],

         ...,

         [[ 3.4622e+00, -1.2417e+00, -5.0749e+00,  ...,  5.3184e+00,
            1.4744e+01,  8.3968e+00],
          [-2.7820e+00, -9.1911e+00, -1.1069e+01,  ...,  2.5380e+00,
            9.8336e+00,  4.0623e+00],
          [-3.9794e+00, -1.0140e+01, -9.9133e+00,  ...,  3.0999e+00,
            5.5936e+00,  2.5775e+00],
          ...,
          [ 2.0299e+00,  2.1304e-01, -2.2307e+00,  ...,  1.1388e+01,
            8.8098e+00,  1.8991e+00],
          [ 8.0663e-01, -1.5073e+00,  3.3977e-01,  ...,  8.5316e+00,
            4.9923e+00, -3.6818e-01],
          [-3.5146e+00, -7.2647e+00, -5.4331e+00,  ..., -1.9781e+00,
           -3.4463e+00, -4.9034e+00]],

         [[-3.2915e+00, -7.3263e+00, -6.8458e+00,  ...,  2.3122e+00,
            9.7774e-01, -1.3498e+00],
          [-4.5396e+00, -8.6832e+00, -8.8582e+00,  ...,  7.1535e-02,
           -4.1133e+00, -4.4045e+00],
          [-4.8781e+00, -7.0239e+00, -4.7350e+00,  ..., -3.6954e+00,
           -9.6687e+00, -8.8289e+00],
          ...,
          [-4.7072e+00, -4.4823e-01,  1.7099e+00,  ...,  3.7923e+00,
            1.6887e+00, -4.3305e+00],
          [-5.5120e+00, -3.2324e+00,  2.3594e+00,  ...,  4.6031e+00,
            1.8856e+00, -4.0147e+00],
          [-5.1355e+00, -5.5335e+00, -1.7738e+00,  ...,  1.6159e+00,
           -1.3950e+00, -4.1055e+00]],

         [[-2.0252e+00, -2.3971e+00, -1.6477e+00,  ..., -3.3740e+00,
           -4.9965e+00, -2.1219e+00],
          [-7.6059e-01, -3.3901e-01, -1.8980e-01,  ..., -4.3286e+00,
           -7.1350e+00, -3.9186e+00],
          [ 8.4101e-01,  1.3403e+00,  2.5821e-01,  ..., -5.1847e+00,
           -7.1829e+00, -3.7724e+00],
          ...,
          [-6.0619e+00, -5.6475e+00, -1.6446e+00,  ..., -9.2322e+00,
           -9.1981e+00, -5.5239e+00],
          [-7.4606e+00, -7.6054e+00, -5.8401e+00,  ..., -7.6998e+00,
           -6.4111e+00, -2.9374e+00],
          [-6.4147e+00, -7.2813e+00, -6.1880e+00,  ..., -4.6726e+00,
           -3.1090e+00, -7.8383e-01]]]])]
[tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 4.1902e+00,
           3.1757e+00, 2.2461e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           9.8012e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           2.5570e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 1.3792e-01, 5.4926e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [1.5377e+00, 6.0671e-01, 2.0974e+00,  ..., 1.2441e+00,
           1.5033e+00, 0.0000e+00],
          [6.8857e-01, 0.0000e+00, 6.7858e-01,  ..., 1.2052e+00,
           1.4533e+00, 0.0000e+00]],

         [[6.8798e-01, 1.6971e+00, 2.1629e+00,  ..., 3.1701e-01,
           8.5424e-01, 2.8768e+00],
          [1.4013e+00, 2.7217e+00, 2.1476e+00,  ..., 3.1156e+00,
           4.4858e+00, 3.6936e+00],
          [3.1807e+00, 2.2245e+00, 2.4665e+00,  ..., 1.3838e+00,
           1.0580e-02, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.3266e-01,
           6.7711e+00, 3.8166e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.7584e+00,
           5.1524e+00, 2.3072e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.9888e+00,
           2.0582e+00, 1.9266e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 8.0185e+00, 6.5506e+00,  ..., 5.3241e+00,
           3.3854e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 4.7097e+00,  ..., 3.2021e+00,
           2.9208e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.8155e-02,
           0.0000e+00, 0.0000e+00]],

         ...,

         [[3.4622e+00, 0.0000e+00, 0.0000e+00,  ..., 5.3184e+00,
           1.4744e+01, 8.3968e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.5380e+00,
           9.8336e+00, 4.0623e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 3.0999e+00,
           5.5936e+00, 2.5775e+00],
          ...,
          [2.0299e+00, 2.1304e-01, 0.0000e+00,  ..., 1.1388e+01,
           8.8098e+00, 1.8991e+00],
          [8.0663e-01, 0.0000e+00, 3.3977e-01,  ..., 8.5316e+00,
           4.9923e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 2.3122e+00,
           9.7774e-01, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 7.1535e-02,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 1.7099e+00,  ..., 3.7923e+00,
           1.6887e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 2.3594e+00,  ..., 4.6031e+00,
           1.8856e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.6159e+00,
           0.0000e+00, 0.0000e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [8.4101e-01, 1.3403e+00, 2.5821e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00]]]])]
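The zeros are very likely caused by an in-place ReLU rather than by the hook itself: in torchvision's VGG, the module right after the conv at index 5 is nn.ReLU(inplace=True), and output.detach() returns a tensor that shares storage with output, so the ReLU overwrites what was saved. The first print runs inside the hook (before the ReLU), the second after the whole forward pass. The sketch below reproduces this on a small stand-in network (a randomly initialised Conv2d followed by an in-place ReLU, not the actual VGG) and shows that saving output.detach().clone() keeps the negative values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for vgg_feature[5:7]: a conv followed by an in-place ReLU, as in torchvision's VGG
net = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(inplace=True))

saved_view, saved_copy = [], []

def hook(module, inputs, output):
    saved_view.append(output.detach())          # shares storage with `output`
    saved_copy.append(output.detach().clone())  # independent copy of the values

net[0].register_forward_hook(hook)
with torch.no_grad():
    net(torch.randn(1, 3, 8, 8))

print((saved_view[0] < 0).any().item())  # False: the in-place ReLU zeroed the negatives
print((saved_copy[0] < 0).any().item())  # True: the clone kept the raw conv output
```

An alternative is setting `inplace = False` on every nn.ReLU in `vgg_feature` before the forward pass, which costs a little extra memory but leaves each layer's output untouched.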

Here is what I did for AlexNet, as an example:

import torch.nn as nn

class AlexNetExtractor(nn.Module):
    def __init__(self, submodule, extracted_layer):
        super(AlexNetExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layer = extracted_layer
        if self.extracted_layer < 6:
            self.features = self._get_features()

    def forward(self, x):
        if self.extracted_layer < 6:
            x = self.features(x)
        else:
            self.submodule.classifier = self._get_classifier()
            x = self.submodule(x)
        return x

    def _get_features(self):
        index = self._find_index()
        features = nn.Sequential(
            # stop at the layer
            *list(self.submodule.features.children())[:index]
        )
        return features

    def _get_classifier(self):
        index = self._find_index()
        classifier = nn.Sequential(
            # stop at the layer
            *list(self.submodule.classifier.children())[:index]
        )
        return classifier

    def _find_index(self):
        switcher = {
            1: 3,   # from features
            2: 6,
            3: 8,
            4: 10,
            5: 13,
            6: 3,   # from classifier
            7: 6
        }
        return switcher.get(self.extracted_layer)

So here I give a layer index (1 to 7) to the extractor; it first finds the corresponding child index to cut at, then returns the features up to that point. This works. However, I'm not sure it is a solid solution, since I'm not a PyTorch expert, and the memory consumption seems high. So my question is whether the code above is a good solution, and what is wrong with it in terms of memory management and complexity.

Hi, can I ask why you used * before list in:

nn.Sequential(*list(model.classifier.children())[:-1])

The asterisk unpacks the list, so its elements are passed to nn.Sequential as separate arguments (“one by one”) instead of as a single list.
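A small illustration with made-up layers: with the asterisk, nn.Sequential receives each layer as its own argument; without it, it receives one list object, which is not an nn.Module and is rejected.

```python
import torch.nn as nn

layers = [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)]

# *layers unpacks the list: same as nn.Sequential(layers[0], layers[1], layers[2])
model = nn.Sequential(*layers)
print(len(model))  # 3

# Without the asterisk the whole list is one argument, and a list is not an nn.Module
try:
    nn.Sequential(layers)
    list_rejected = False
except TypeError:
    list_rejected = True
print(list_rejected)  # True
```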


Got it, thank you! Can I also ask what a module or a submodule is? That's what children() returns, isn't it? Or if you'd rather point me to some really good PyTorch docs, I would totally go for it.
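In PyTorch, a module is any nn.Module subclass (a single layer, a block, or a whole model), and the modules assigned as attributes of another module are its submodules. children() yields only the direct submodules, while modules() walks the whole tree, including the module itself. A small illustration with a made-up nested model:

```python
import torch.nn as nn

block = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())
model = nn.Sequential(block, nn.Flatten())

# Direct submodules only
for name, child in model.named_children():
    print(name, type(child).__name__)
# 0 Sequential
# 1 Flatten

print(len(list(model.children())))  # 2: block and Flatten
print(len(list(model.modules())))   # 5: model, block, Conv2d, ReLU, Flatten
```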

I have a Faster R-CNN with a pretrained ResNet-101 model as the head. I trained it on KITTI, and now I want to look at the activations of all the Conv2d layers of the ResNet-101 in BasicBlock and compare them when the input is an augmented image. Could you help me with how to save these activations?
Thanks

Double post from here. Let’s stick to a single thread to avoid confusion. :wink:

This is useful for me.

Hi, for VGG16, if I want to extract the features after the last ReLU of each convolutional block, the example code is shown below.

import torch.nn as nn
from torchvision import models

class Vgg16(nn.Module):
    def __init__(self, pretrained=True):
        super(Vgg16, self).__init__()
        self.net = models.vgg16(pretrained=pretrained).features.eval()

    def forward(self, x):
        out = []
        for i, layer in enumerate(self.net):
            x = layer(x)
            # relu1_2, relu2_2, relu3_3, relu4_3, relu5_3 (last ReLU of each block)
            if i in [3, 8, 15, 22, 29]:
                out.append(x)
        return out

To see the module indices of VGG16, print its features with the code below:

from torchvision import models

vgg = models.vgg16(pretrained=True).features.eval()
print(vgg)

I have faced this problem many times, and read this thread even more times!
It's a bit dated now, but I thought sharing this could help people in the future.

I recently tidied up my feature-extractor code and packaged it as a standalone Python package (pip install torchextractor). The implementation is based on @Neta_Zmora's pseudocode description of how Distiller solves this problem.

It doesn't rely on the assumption that modules are declared in the same order they are used, it doesn't require overriding the forward function, and it supports nested modules.

import torch
import torchvision
import torchextractor as tx

model = torchvision.models.resnet18(pretrained=True)
model = tx.Extractor(model, ["layer1", "layer2.1.conv1", "layer3.0.downsample.0", "layer4.0"])
dummy_input = torch.rand(7, 3, 224, 224)

model_output, features = model(dummy_input)
feature_shapes = {name: f.shape for name, f in features.items()}
print(feature_shapes)
# {
#   'layer1': torch.Size([7, 64, 56, 56]),
#   'layer2.1.conv1': torch.Size([7, 128, 28, 28]),
#   'layer3.0.downsample.0': torch.Size([7, 256, 14, 14]),
#   'layer4.0': torch.Size([7, 512, 7, 7]),
# }

I hope people will find it easier to extract features, add losses or build an additional head to their models.


Thanks for adding this. I am trying to access an intermediate layer of the model below, but I'm having a bit of trouble:

import torch
import torchextractor as tx

model1 = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', pretrained=True)
dummy_input = torch.rand(1, 3, 8, 256, 256)

model = tx.Extractor(model1, ["4"])
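One thing worth checking is whether "4" is a full module name. The names in the resnet18 example (such as "layer2.1.conv1") are the dotted paths reported by named_modules(), and, assuming tx.Extractor resolves names the same way, a model that keeps its stages in an nn.ModuleList attribute would need the attribute prefix, not a bare index. The sketch below uses a hypothetical ToyNet with a ModuleList called `blocks` to show how to list the valid names; running the same list comprehension on model1 should show which names slow_r50 actually exposes.

```python
import torch.nn as nn


class ToyNet(nn.Module):
    """Hypothetical stand-in for a model that stores its stages in a ModuleList."""

    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Conv3d(3, 8, 1), nn.ReLU(), nn.Conv3d(8, 8, 1)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = ToyNet()
names = [name for name, _ in model.named_modules() if name]  # skip the root module ('')
print(names)  # ['blocks', 'blocks.0', 'blocks.1', 'blocks.2']
```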