How to extract features of an image from a trained model

You can use the torchvision.models package, which provides functions for constructing various vision models, with the option of loading pretrained weights. This:

torchvision.models.resnet18(pretrained=True)

will give you an nn.Module with the downloaded weights.
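For example, a minimal sketch of loading the weights and running a dummy image through the network (the 224x224 input size and the eval() call are my additions, matching the usual ImageNet setup, not part of the post above):

import torch
from torchvision import models

model = models.resnet18(pretrained=True)
model.eval()  # put batch norm layers into inference mode

x = torch.rand(1, 3, 224, 224)   # dummy batch of one image
scores = model(x)                # shape (1, 1000): one score per ImageNet class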


To complement @apaszke’s reply, once you have a trained model, if you want to extract the result of an intermediate layer (say fc7 after the relu), you have a couple of possibilities.

You can either reconstruct the classifier once the model has been instantiated, as in the following example:

import torch
import torch.nn as nn
from torchvision import models

model = models.alexnet(pretrained=True)

# remove last fully-connected layer
new_classifier = nn.Sequential(*list(model.classifier.children())[:-1])
model.classifier = new_classifier
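With the classifier truncated like this, a forward pass now returns the fc7 activations instead of class scores. A quick sketch (the dummy input and the eval() call are assumptions on my part):

model.eval()                     # disable the dropout layers in the classifier
x = torch.rand(1, 3, 224, 224)   # dummy image batch
fc7 = model(x)                   # should have shape (1, 4096) for torchvision's AlexNet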

Or, if you instead want to extract other parts of the model, you might need to recreate the model structure, reusing parts of the pre-trained model in the new model.

import torch
import torch.nn as nn
from torchvision import models

original_model = models.alexnet(pretrained=True)

class AlexNetConv4(nn.Module):
    def __init__(self):
        super(AlexNetConv4, self).__init__()
        self.features = nn.Sequential(
            # stop at conv4
            *list(original_model.features.children())[:-3]
        )

    def forward(self, x):
        x = self.features(x)
        return x

model = AlexNetConv4()
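And a quick check of what the truncated network returns (the input size and the expected shape are assumptions based on the standard torchvision AlexNet):

x = torch.rand(1, 3, 224, 224)
conv4_features = model(x)
print(conv4_features.shape)   # expected to be roughly (1, 256, 13, 13) for a 224x224 input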

@apaszke @fmassa Thank you guys so much!

Is there a way to get values from multiple layers with one forward pass (for neural style transfer etc.)?

In TensorFlow it would be along the lines of

features = sess.run([net.conv1, net.conv2, net.conv3])

@kpar Yes, it’s possible and very easy in PyTorch. All you need to do is return the outputs you want to retrieve.
Here is an example

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 1, 3)
        self.conv2 = nn.Conv2d(1, 1, 3)
        self.conv3 = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        out1 = F.relu(self.conv1(x))
        out2 = F.relu(self.conv2(out1))
        out3 = F.relu(self.conv3(out2))
        return out1, out2, out3
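Calling the network then gives you all three feature maps in one forward pass; a brief usage sketch (the input size is an arbitrary assumption):

import torch

net = Net()
out1, out2, out3 = net(torch.rand(1, 1, 28, 28))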

That looks like it’ll do the trick.

Is there a convenient way to fetch the intermediate values when the forward behavior is defined by nn.Sequential()? It seems like right now the only way to compose multiple responses would be to split off all the individual layers and forward the values manually in forward().

Essentially what I want to do is take an existing network (e.g. VGG) and just pick some responses of some layers (conv1_1, pool1, pool2, etc.) and concatenate them into a feature vector.


You could write your own sequential version that keeps track of all intermediate results in a list. Something like

class SelectiveSequential(nn.Module):
    def __init__(self, to_select, modules_dict):
        super(SelectiveSequential, self).__init__()
        for key, module in modules_dict.items():
            self.add_module(key, module)
        self._to_select = to_select

    def forward(self, x):
        selected = []
        for name, module in self._modules.items():
            x = module(x)
            if name in self._to_select:
                selected.append(x)
        return selected

And then you could use it like

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.features = SelectiveSequential(
            ['conv1', 'conv3'],
            {'conv1': nn.Conv2d(1, 1, 3),
             'conv2': nn.Conv2d(1, 1, 3),
             'conv3': nn.Conv2d(1, 1, 3)}
        )

    def forward(self, x):
        return self.features(x)
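Calling it then returns just the selected activations, in the order they were produced; for instance (the input size is an assumption):

import torch

net = Net()
conv1_out, conv3_out = net(torch.rand(1, 1, 16, 16))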

Thank you for your help

@kpar, another way to extract features layer by layer from a pre-existing network is to build a new network, adding one by one all the layers from the pre-trained network, plus some “transparent layers” that just store the features passing through them:

net = models.alexnet(pretrained=True).features

class Feature_extractor(nn.Module):
    def forward(self, input):
        # "transparent" layer: store a copy of the input and pass it through unchanged
        self.feature = input.clone()
        return input

new_net = nn.Sequential().cuda()  # the new network

target_layers = ['conv_1', 'conv_2', 'conv_4']  # layers you want to extract

i = 1
for layer in list(net):
    if isinstance(layer, nn.Conv2d):
        name = "conv_" + str(i)
        new_net.add_module(name, layer)

        if name in target_layers:
            new_net.add_module("extractor_" + str(i), Feature_extractor())

        i += 1

    if isinstance(layer, nn.ReLU):
        name = "relu_" + str(i)
        new_net.add_module(name, layer)

    if isinstance(layer, nn.MaxPool2d):
        name = "pool_" + str(i)
        new_net.add_module(name, layer)

new_net(your_image)
print(new_net.extractor_4.feature)

@alexis-jacq I wouldn’t recommend that. It’s better to keep your models stateless, i.e. not have them hold any intermediate state. Otherwise, if you don’t pay enough attention, you might end up holding references to graphs you no longer need, and they will just keep taking up memory.

If you really want to do something like that, I’d recommend this:

class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
            x = module(x)
            if name in self.extracted_layers:
                outputs += [x]
        return outputs + [x]

This unfortunately uses a private member _modules, but I don’t expect it to change in the near future, and we’ll probably expose an API for iterating over modules with names soon.
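For instance, wrapping the convolutional part of AlexNet might look like this (the layer indices '3' and '8' are arbitrary examples I picked; children of an nn.Sequential are registered under their string indices '0', '1', …):

import torch
from torchvision import models

alexnet = models.alexnet(pretrained=True)
# pick two intermediate layers of alexnet.features by their string indices
extractor = FeatureExtractor(alexnet.features, extracted_layers=['3', '8'])
outputs = extractor(torch.rand(1, 3, 224, 224))
# outputs is a list: the two selected activations plus the final output of .features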


Wow! good to know that… Thanks!

That would be really useful!


Hmm… all is well for simple networks, like AlexNet.
What if I want to use something a bit more fancy, say ResNet-18, and wish to extract some intermediate representation?
I can fetch the model with torchvision.models.resnet18(pretrained=True). I think the network graph can be reconstructed by forwarding a Variable and checking who the parent is, but I’m not sure how to do this other than by hand.
Moreover, I think the classes attribute (the tuple mapping output class to label) is missing.

Can I register a hook on a specific child and have it return its current output? In the old Torch I would write model:get(myModuleIndex).output, or a more nested combination, and get myModuleIndex from visual inspection of the model structure (using the __tostring() metamethod).

How do we hack these models? :confused:

OK, I’ve managed to display the graph with from visualize import make_dot (here’s visualize).
Now how can I reach inside the graph?

So, let’s call the output h_x (which is $h_\theta(x)$).
I can peek at my embedding if I do the following dirty sequence of operations (which I guess is not recommended).

x = Variable(torch.rand(1, 3, 224, 224))
h_x = resnet_18.forward(x)
last_view = h_x.creator.previous_functions[0][0]
last_pool = last_view.previous_functions[0][0]
embedding = last_pool.saved_tensors[0]

And embedding should correspond to the output of the last Threshold, if I’m not mistaken. Still, this is not generally applicable to every functional block, since not all blocks cache their input.
So, I’m still after a better way to dissect these nets…


Why do you want to inspect the graph? A new instance is going to be created at every forward.

You can register a hook on any Variable using the register_hook method, or on any module using register_backward_hook.
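For illustration, a rough sketch of both hook flavours on a toy module (the module, shapes, and print statements are assumptions; written against the current tensor API rather than the old Variable one):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 1, 3)

# module backward hook: receives the gradients flowing into and out of the module
def backward_hook(module, grad_input, grad_output):
    print('grad_output shape:', grad_output[0].shape)

conv.register_backward_hook(backward_hook)

x = torch.rand(1, 1, 8, 8, requires_grad=True)
y = conv(x)
# tensor hook: receives the gradient of y during the backward pass
y.register_hook(lambda grad: print('grad of y:', grad.shape))

y.sum().backward()  # triggers both hooks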

What are you trying to do, and why doesn’t the FeatureExtractor I showed above work for that?


Hmm, I thought that the FeatureExtractor you showed above could only deal with sequential models.
The loop

for name, module in self.submodule._modules.items():
    x = module(x)

forwards the output of each child of submodule, one after another, in a sequential manner, no?
Maybe I don’t understand what _modules provides…

Yeah it does, but you can easily adapt it to other network topologies by replacing self.submodule with some other reference. resnet18 is sequential if you only want to inspect the maps between the blocks.

I think @Atcold meant that FeatureExtractor only works for networks whose submodules are the layers themselves. The problem is not solved for networks with “sub-sub-modules” like the resnet18 structure; maybe we need a recursive pass somewhere, in order to access the “leaf modules” of the graph…
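As a rough sketch of that recursive pass (the helper name and traversal are my own, not an established API): walk _modules depth-first and yield only the modules that have no children of their own.

def named_leaf_modules(module, prefix=''):
    # depth-first walk over _modules, yielding (qualified_name, leaf_module) pairs
    for name, child in module._modules.items():
        qualified = prefix + '.' + name if prefix else name
        if len(child._modules) == 0:
            yield qualified, child
        else:
            yield from named_leaf_modules(child, qualified)

# for resnet18 this yields names like 'conv1', 'layer1.0.conv1', ..., 'avgpool', 'fc'

(Current PyTorch also exposes nn.Module.named_modules(), which performs a similar named traversal, including non-leaf modules.)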


Thank you @alexis-jacq for addressing my specific issue.
It’s been suggested that I use nn.Module.register_forward_hook() to perform this job.
I still have to figure out a smart way to get the reference to specific nn.Modules within a graph, but I think I can play with the visualisation script I posted above.
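For the record, a minimal sketch of what that hook-based extraction could look like on resnet18 (the choice of the avgpool layer, the features dict, and the dummy input are illustrative assumptions, not from the thread):

import torch
from torchvision import models

resnet_18 = models.resnet18(pretrained=True)
resnet_18.eval()

features = {}

def save_output(module, input, output):
    # forward hooks are called with the module, its input tuple, and its output
    features['avgpool'] = output.detach()

handle = resnet_18.avgpool.register_forward_hook(save_output)

resnet_18(torch.rand(1, 3, 224, 224))
embedding = features['avgpool'].flatten(1)  # the 512-dim embedding fed into fc
handle.remove()  # remove the hook once you are done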

@albanD, this is why I think the graph integration with your TensorBoard client is essential.
So that I can click around and add some hooks on the fly.

@Atcold The issue with adding graph integration to Crayon is that we need to essentially write an interface between whatever TensorBoard reads and PyTorch (or Torch, Theano, CNTK, etc.).

We managed to quickly write what we currently have because getting TensorBoard to take 2d values / histograms was relatively straightforward; however, adding integration for more complex data types is not a one-day job.

After the ICML deadline we’ll give it a look.


@Atcold As I said, there is absolutely no need to inspect the graph for your purposes. It’s possible that you won’t find most of the intermediate values, because they’re not needed and have already been freed, and we don’t guarantee any stability w.r.t. adding hooks to the internal graph objects or the internal data structures for now. There’s a lot of work going on around autograd and we can’t be limited in what we can change. If anyone wants to develop some kind of helpers like TensorBoard integration or serialization, let us know and we’ll keep you posted. Otherwise, it’s a bad idea to depend on that.

Forward hooks are a good idea. The FeatureExtractor could easily be adapted to recursively list all submodules with their names (see e.g. the implementation of state_dict()). We’ll be adding such a method to nn soon.
