How to extract features of an image from a trained model

@atcold As I said, there is absolutely no need to inspect the graph for your purposes. It’s possible that you won’t find most of the intermediate values, because they’re not needed and have already been freed, and we don’t guarantee any stability w.r.t. adding hooks to the internal graph objects or the internal data structures for now. There’s a lot of work going on around autograd and we can’t afford to be limited in what we can change. If anyone wants to develop some kind of helpers like TensorBoard integration or serialization, let us know and we’ll keep you posted. Otherwise, it’s a bad idea to depend on that.

Forward hooks are a good idea. The feature extractor could easily be adapted to recursively list all submodules with their names (see e.g. the implementation of state_dict()). We’ll be adding such a method to nn soon.
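Current PyTorch releases do provide this as `nn.Module.named_modules()` (recursive) and `nn.Module.named_children()` (direct children only). A minimal sketch on a small toy model of our own (not ResNet-18), showing the qualified names it yields:

```python
import torch.nn as nn

# A small, hypothetical model with nested submodules, used only for illustration.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.Sequential(nn.ReLU(), nn.Conv2d(8, 16, kernel_size=3)),
    nn.AdaptiveAvgPool2d(1),
)

# named_modules() walks the hierarchy recursively and yields (qualified_name, module),
# starting with the root module itself under the empty name ''.
names = [name for name, _ in model.named_modules()]
for name, module in model.named_modules():
    print(repr(name), type(module).__name__)
```

The dotted names (e.g. `'1.1'`) are the same keys you would use to look a submodule up before attaching a forward hook to it.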


OK, it starts to make sense.
I’ll post here some notes, for future reference (there is also a notebook, here).

One thing I’ve figured out, by inspecting the network graph, is that it is made of Variables (or Parameters, which are technically Variables too) and _functions.

Given that h_x = resnet_18(x), I have that h_x.creator is a torch.nn._functions.linear.Linear object, whose previous_functions is a tuple of length 3, containing

  1. a torch.autograd._functions.tensor.View object,
  2. a weight matrix of size (1000, 512)
  3. a bias vector of size (1000)
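For readers on a recent PyTorch: `creator` has since been renamed `grad_fn`, `previous_functions` is now `next_functions` (a tuple of `(node, index)` pairs), and leaf parameters show up as `AccumulateGrad` nodes rather than as the tensors themselves. A small sketch that walks such a graph, using a hypothetical `nn.Linear(512, 1000)` as a stand-in for ResNet-18's last layer:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for resnet_18's final layer: a 512 -> 1000 linear map.
fc = nn.Linear(512, 1000)
h_x = fc(torch.randn(1, 512))

def walk(fn, depth=0, seen=None, out=None):
    """Collect (depth, node name) pairs for the autograd graph that ends at `fn`."""
    seen = set() if seen is None else seen
    out = [] if out is None else out
    if fn is None or fn in seen:   # None marks an input that does not require grad
        return out
    seen.add(fn)
    out.append((depth, type(fn).__name__))
    for next_fn, _ in fn.next_functions:
        walk(next_fn, depth + 1, seen, out)
    return out

for depth, name in walk(h_x.grad_fn):
    print('  ' * depth + name)
```

The exact node names (e.g. `AddmmBackward0`) differ between PyTorch versions, which is one more reason not to build tooling on top of them.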

Some of these _functions objects have cached values (as long as volatile is False for the input Variable), but one cannot rely on them.
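The `volatile` flag no longer exists in current PyTorch; its role is played by the `torch.no_grad()` context manager. A minimal sketch showing that, inside it, no graph (and hence nothing cached on it) is recorded:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
x = torch.randn(3, 4)

y = layer(x)                  # graph is recorded: y.grad_fn is set
with torch.no_grad():         # modern replacement for volatile=True
    y_ng = layer(x)           # no graph is recorded: y_ng.grad_fn is None
```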

On the other hand, by typing print(resnet_18), one can visualise the network’s Modules’ names and their respective reprs. (The output is quite lengthy, so I will not copy it here.)
Some of these Modules contain other Modules, and the repr unrolls them in insertion order (i.e. the order in which they were assigned as attributes of their parent Module), which may not reflect the order in which they are used in the forward() method.
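A toy illustration of that mismatch, assuming a hypothetical module that registers `fc` before `conv` but calls them in the opposite order. Registration order comes from `named_children()`, while forward hooks record the order in which the children actually run:

```python
import torch
import torch.nn as nn

class Reordered(nn.Module):
    """Hypothetical module whose submodules are registered in one order but used in another."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 2)        # registered first
        self.conv = nn.Conv2d(3, 8, 3)   # registered second
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        x = self.pool(self.conv(x))      # conv and pool run first...
        return self.fc(x.flatten(1))     # ...fc runs last

model = Reordered()
registration_order = [name for name, _ in model.named_children()]

call_order = []
for name, child in model.named_children():
    # Forward hooks fire in execution order; the default argument binds the current name.
    child.register_forward_hook(lambda m, i, o, name=name: call_order.append(name))

model(torch.randn(1, 3, 16, 16))
print(registration_order)  # ['fc', 'conv', 'pool']
print(call_order)          # ['conv', 'pool', 'fc']
```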

Now, let’s have a look at ResNet-18.

```python
>>> resnet_18._modules.keys()
odict_keys(['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool', 'fc'])
```

We can now register a forward hook to avgpool, and have it return its output Variable.
Let’s check first what this hook gets as parameters.

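```python
# resnet_18 and the input batch x are assumed to be defined earlier in the thread
avgpool_layer = resnet_18._modules.get('avgpool')
h = avgpool_layer.register_forward_hook(
    # the hook receives (module, input tuple, output); printing returns None,
    # so the module's output is left unchanged
    lambda m, i, o: print(
        'm:', type(m),
        '\ni:', type(i),
            '\n   len:', len(i),
            '\n   type:', type(i[0]),
            '\n   data size:', i[0].data.size(),
            '\n   data type:', i[0].data.type(),
        '\no:', type(o),
            '\n   data size:', o.data.size(),
            '\n   data type:', o.data.type(),
    )
)
h_x = resnet_18(x)
```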

gives us

```
m: <class 'torch.nn.modules.pooling.AvgPool2d'>
i: <class 'tuple'>
   len: 1
   type: <class 'torch.autograd.variable.Variable'>
   data size: torch.Size([1, 512, 7, 7])
   data type: torch.FloatTensor
o: <class 'torch.autograd.variable.Variable'>
   data size: torch.Size([1, 512, 1, 1])
   data type: torch.FloatTensor
```

Sweet. We can now create a Tensor of size 512 and copy over the embedding.

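```python
import torch

my_embedding = torch.zeros(512)

def fun(m, i, o):
    # o has shape (1, 512, 1, 1); flatten it to (512,) before copying
    my_embedding.copy_(o.data.view(-1))

h.remove()  # detach the printing hook registered above
h = avgpool_layer.register_forward_hook(fun)
h_x = resnet_18(x)
```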

Now the Tensor my_embedding will contain what we were looking for.
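The same idea can be packaged as a small reusable helper that hooks several named submodules at once and removes its hooks afterwards. This is a sketch, not an official API: `TinyNet` is a hypothetical stand-in that merely reuses ResNet-18's `avgpool`/`fc` attribute names, and `extract` is a helper name introduced here for illustration.

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in for resnet_18, reusing its 'avgpool' and 'fc' names."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 512, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 1000)

    def forward(self, x):
        x = self.avgpool(torch.relu(self.conv1(x)))
        return self.fc(x.flatten(1))

def extract(model, x, layer_names):
    """Run `model` on `x` and return {name: output} for the requested submodules."""
    modules = dict(model.named_modules())
    features, handles = {}, []
    for name in layer_names:
        def hook(m, i, o, name=name):
            features[name] = o.detach().clone()
        handles.append(modules[name].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()   # leave the model without leftover hooks
    return features

model = TinyNet().eval()
feats = extract(model, torch.randn(2, 3, 8, 8), ['avgpool'])
embedding = feats['avgpool'].flatten(1)   # shape (2, 512)
```

Cloning inside the hook avoids surprises if a later layer modifies the tensor in place.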

The only point I am still not that confident about is the connection between a torch.autograd._functions.something.Something object and the Module that created it. Right now I’m just guessing the paternity.


Modules are users of Functions. There’s no inheritance going on.

Right. So, how would you go about identifying which Module has used which Function?
As of right now I just guess this by the order of insertion in the _modules ordered dictionary.

Functions hold no references to modules, and modules hold no references to Functions (they are stateless).

Alright, then, how do you know which Module has created which Function, other than by guessing?
I mean, how do I know to which module I shall attach a forward hook in order to capture the input / output at a specific Function node in the graph?

There’s no concept of Function hooks, but capturing the input/output of a particular one can be done by adding some saving logic to forward; if you need the grad, you can add a Variable hook.
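A minimal sketch of both halves of that suggestion. In current PyTorch, Variables have been merged into Tensors, so the "Variable hook" is `Tensor.register_hook`; the computation below is a toy example of our own:

```python
import torch

x = torch.randn(3, requires_grad=True)
saved = {}

# An intermediate value in some forward computation; store its value directly...
h = x * 2
saved['h'] = h.detach().clone()

# ...and attach a tensor hook, which autograd calls with dLoss/dh during backward().
def save_grad(grad):
    saved['h_grad'] = grad.detach().clone()   # returning None leaves the gradient unchanged

h.register_hook(save_grad)
loss = (h ** 2).sum()
loss.backward()
```

Here `loss = sum(h**2)`, so the captured gradient is `2 * h`, and `x.grad` ends up as `4 * h`.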

Given that the model is already built, you don’t really want to edit its forward() method.
The forward() method does call its Modules from the _modules ordered dictionary attribute. These Modules return a Function.
Hence, one can hook the Module to get a given Function input and output.
The question is still the same: how does one determine which Function has been created by which Module.

I’m not sure whether I am making sense, but it seems pretty logical for a Function to keep a reference to the name of the Module that created it. That way one could draw the network graph and label its nodes.

We were discussing this yesterday, on Slack ->

The module already being built has nothing to do with its forward method: you could, for instance, monkeypatch forward, or rewrite it and load a saved model; there’s no sense in which the forward code is an inherent part of the model object hierarchy.

Why do you even need to know what Module instantiated which Function? Functions can be instantiated by regular method calls on Variables too. There’s nothing special in the relation of Modules and Functions and no references between them exist.

It’s not logical to me: that design would introduce a dependency of torch.autograd Functions on torch.nn modules, whereas right now it’s the reverse (i.e. the torch.nn package depends on torch.autograd). This would leak logic into other packages unrelated to Modules in any way, and add cyclic dependencies; that’s bad design. We also have a purely functional API to nn that has no notion of modules. You have to understand that PyTorch is different from Lua Torch: it’s torch.autograd that’s the core of the framework, not torch.nn. The new nn is a very, very simple wrapper that makes autograd more convenient to use, but nothing more.
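One way to see this in practice: calling the functional API directly with a module's own parameters produces the same kind of autograd node as calling the module, so the node itself cannot tell whether a Module was involved. A small sketch using `nn.Linear` and `torch.nn.functional.linear`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

layer = nn.Linear(4, 3)
x = torch.randn(2, 4)

out_module = layer(x)                                    # via the nn.Module wrapper
out_functional = F.linear(x, layer.weight, layer.bias)   # same op, no module involved

# Both outputs are produced by the same kind of autograd node; nothing in the
# node records whether an nn.Module was involved.
same_node_type = type(out_module.grad_fn) is type(out_functional.grad_fn)
print(same_node_type)  # True
```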

To sum up, there’s no way to automatically find out which module has created which function. There are better ways to solve graph labelling and they will be added in the future.

Thank you, @apaszke, I think I’m starting to understand (sorry, I’m slow…).
Could you give me a hint about these “better ways” you are referring to? I’m just super curious now :grin:

Answering your question “Why do you even need to know what Module instantiated which Function?” is pretty simple: so that I know which Module I should register a forward hook on in order to inspect the Tensor's flow through the graph.
Even better, would it make sense to have a register_forward/backward_hook() for the Variable class instead of the current register[_backward]_hook()?

Why do you even care what the gradients are in some part of the graph? You should care what the gradients are w.r.t. some intermediate value or module, not where they are in the graph. The graph is purely an internal representation of your computation, and it doesn’t necessarily have to mirror what you’ve done in the code. If you care about the grad or outputs of some module, use module hooks; if you care about the grad w.r.t. some Variable, attach a hook to the Variable (there already is register_hook).
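For the module side, current PyTorch (1.8 and later) offers `nn.Module.register_full_backward_hook`, which receives the gradients w.r.t. the module's inputs and outputs. A minimal sketch on a toy model of our own, capturing the gradient flowing into the output of a ReLU:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
grads = {}

def save_grad_output(module, grad_input, grad_output):
    # grad_output[0] is dLoss/d(module output); returning None leaves gradients unchanged.
    grads['relu_out'] = grad_output[0].detach()

handle = model[1].register_full_backward_hook(save_grad_output)
x = torch.randn(5, 4)
model(x).sum().backward()
handle.remove()
```

Because the loss is a plain sum of the last layer's outputs, each row of the captured gradient equals that layer's weight row.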

Hi all,

I think @fmassa’s method for extracting features can only deal with models that take a single kind of input (e.g. images). What if a model has more than one input (e.g. images and language, where we want to measure their distance)?

@big_tree use register_forward_hook.
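A forward hook receives all positional inputs of `forward()` as one tuple, so multi-input models work the same way. A sketch with a hypothetical two-input model (image and text vectors projected into a shared space, returning their distance); the layer sizes are arbitrary:

```python
import torch
import torch.nn as nn

class PairDistance(nn.Module):
    """Hypothetical two-input model: projects image and text vectors, returns their distance."""
    def __init__(self):
        super().__init__()
        self.img_proj = nn.Linear(512, 128)
        self.txt_proj = nn.Linear(300, 128)

    def forward(self, img, txt):
        return torch.norm(self.img_proj(img) - self.txt_proj(txt), dim=1)

model = PairDistance()
captured = {}

def hook(module, inputs, output):
    # `inputs` is a tuple with one entry per positional argument of forward().
    captured['img'], captured['txt'] = (t.detach() for t in inputs)
    captured['dist'] = output.detach()

model.register_forward_hook(hook)
model(torch.randn(4, 512), torch.randn(4, 300))
```

To get each branch's embedding instead, hook `model.img_proj` and `model.txt_proj` individually.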


thank you very much!

Hi, can I ask how to tell that ResNet-18 has 18 layers?
If I do print(resnet_18()), it seems like there are 46 layers in the module.
And I saw the visualised picture in your GitHub, but I still can’t figure out why it has 18 layers. Could you give a brief explanation?

You need to count the number of convolution layers, I think.
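More precisely, the 18 counts the weighted layers on the main path: the stem convolution `conv1`, the two 3×3 convolutions in each of the two BasicBlocks of the four stages `layer1` to `layer4`, and the final fully connected layer `fc`. BatchNorm, ReLU and pooling are not counted, and neither are the three 1×1 projection convolutions in the `downsample` shortcuts (which is why a plain count of `Conv2d` modules gives 20 rather than 17):

$$
\underbrace{1}_{\texttt{conv1}} + \underbrace{4 \times 2 \times 2}_{\text{BasicBlock convs}} + \underbrace{1}_{\texttt{fc}} = 18
$$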


Thanks a lot !:grin:

Would someone be willing to invite me to the PyTorch Slack?

I tried to finetune on vgg16 with a different input size (224–>256), I am using this code:

`pretrained_model = torchvision.models.vgg16(pretrained=True)`
`modified_pretrained = nn.Sequential(*list(pretrained_model.features.children())[:-1]) # to relu5_3`

But why do I get this error when I run a forward pass?
RuntimeError: size mismatch at /home/jcc/pytorch/torch/lib/THC/generic/
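The error message above is cut off, so the cause can't be confirmed, but a likely one is the input to a Linear layer. torchvision's VGG16 classifier starts with `Linear(512 * 7 * 7, 4096)`, sized for 224×224 inputs: five 2×2 max-pools reduce 224 to 7, giving 512 × 7 × 7 = 25088 features. At 256×256 the full feature stack gives 512 × 8 × 8 = 32768, and with the last max-pool dropped by `[:-1]` it gives 512 × 16 × 16 = 131072; any Linear layer placed after the features must match that number. (Newer torchvision releases insert an `AdaptiveAvgPool2d((7, 7))` before the classifier, which sidesteps this.) A sketch of a hypothetical helper that measures the flattened size with a dummy forward pass, shown on a toy feature stack rather than VGG16 itself:

```python
import torch
import torch.nn as nn

def flat_feature_size(features, input_shape):
    """Number of values `features` produces per sample for one input of `input_shape` (C, H, W)."""
    with torch.no_grad():
        out = features(torch.zeros(1, *input_shape))
    return out[0].numel()

# Toy stand-in for a VGG-style feature stack (hypothetical): two 2x2 max-pools halve H and W twice.
toy_features = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
)
print(flat_feature_size(toy_features, (3, 224, 224)))  # 32 * 56 * 56 = 100352
print(flat_feature_size(toy_features, (3, 256, 256)))  # 32 * 64 * 64 = 131072
```

The value returned is what the `in_features` of the first Linear layer after the feature stack would need to be.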