Accessing any feature of Pretrained Net (by Name)

This is sort of an extended version of the question "Accessing intermediate layers of a pretrained network forward?".

Coming from TensorFlow (and liking it here, this is a great community!), I am probably just approaching this the wrong way.
I am trying to write a really simple feature extraction script.
Given an input image (or a stack of images), I just want to compute some features using a pretrained VGG19 and dump them into a binary file. No backprop, no bells or whistles, really.
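
For the "no backprop" part, here is a minimal sketch of the compute-and-dump step, assuming any callable that maps a batch to a feature tensor (the helper name `dump_features` is hypothetical). `torch.no_grad()` skips building the autograd graph, `eval()` disables Dropout (which matters for fc6/fc7 features), and `ndarray.tofile` writes raw float32 with no header, so whoever reads the file has to know the shape; `np.save` would store the shape as well.

```python
import numpy as np
import torch
import torch.nn as nn


def dump_features(feature_fn, images, path):
    """Compute features without autograd and write them to `path` as raw float32.

    Returns the array shape, since the headerless file does not record it.
    """
    if isinstance(feature_fn, nn.Module):
        feature_fn.eval()  # disable Dropout / use running BatchNorm statistics
    with torch.no_grad():  # no backprop needed, so no autograd graph
        feats = feature_fn(images)
    arr = feats.cpu().numpy().astype(np.float32)
    arr.tofile(path)  # raw bytes in C order, no shape information
    return arr.shape
```

Reading the file back would then be `np.fromfile(path, dtype=np.float32).reshape(shape)`.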

Thing is, which features exactly get extracted is a parameter of my script.
In TensorFlow, this would essentially be something along the lines of

```python
logits, end_points = vgg19(inputs)
return [end_points[k] for k in targets]
```

where targets contains layer names such as conv1_1 or fc6.
How do I do something similar in PyTorch?
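
A minimal sketch of one PyTorch counterpart to `end_points`, assuming the dotted names reported by `model.named_modules()` are acceptable as target names (in torchvision's `vgg19`, `features.0` is conv1_1 and `classifier.0` is fc6): register a forward hook on each requested submodule, run one forward pass, and collect what the hooks recorded. `extract_features` is a hypothetical helper, not a PyTorch API. Because it only relies on `named_modules()`, it should also work for models like ResNet that have no `features`/`classifier` attributes.

```python
import torch
import torch.nn as nn


def extract_features(model: nn.Module, x: torch.Tensor, targets):
    """Return the outputs of the submodules named in `targets`, in that order.

    Names are the dotted paths from `model.named_modules()`, e.g. 'features.0'
    (conv1_1) or 'classifier.0' (fc6) for torchvision's VGG19.
    """
    modules = dict(model.named_modules())
    missing = [t for t in targets if t not in modules]
    if missing:
        raise KeyError(f'unknown module names: {missing}')

    outputs = {}
    handles = []
    for name in targets:
        # the default argument binds the current name to this particular hook
        def hook(module, inputs, output, name=name):
            # clone so a later in-place op (e.g. ReLU(inplace=True)) cannot change it
            outputs[name] = output.detach().clone()
        handles.append(modules[name].register_forward_hook(hook))
    try:
        with torch.no_grad():  # feature extraction only, no backprop
            model(x)  # full forward pass; the hooks record the requested outputs
    finally:
        for handle in handles:
            handle.remove()  # leave the model unchanged for later use
    return [outputs[name] for name in targets]
```

Call `model.eval()` first so Dropout does not affect fc features. Note that hooks do not stop the forward pass early, and a module that is called twice keeps only its last output. Newer torchvision releases also provide `torchvision.models.feature_extraction.create_feature_extractor`, which builds a similar extractor via FX tracing; its documentation lists the node names it expects.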

From what I have gathered so far, this would be one possible approach:

  • create a subclass of nn.Module
  • inside that module, load vgg19
  • rebuild the entire original VGG19 model from the submodules of the stock vgg19, giving each submodule a name so it can be addressed by that name later
  • define forward() to run the forward pass up to the last target, collecting the intermediate results in a list and returning it.
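
The steps in this list can be sketched fairly compactly. This is a minimal sketch, assuming the network has already been rebuilt as an `OrderedDict` of `(name, module)` pairs; the class name `NamedFeatureExtractor` and any layer names are hypothetical. `nn.Sequential` accepts such an `OrderedDict` and keeps the names as attributes, so the forward pass is a single loop that records the requested outputs and stops after the last one it needs.

```python
from collections import OrderedDict

import torch.nn as nn


class NamedFeatureExtractor(nn.Module):
    """Run named layers in order; return the outputs of the layers in `targets`."""

    def __init__(self, layers: OrderedDict, targets):
        super().__init__()
        unknown = set(targets) - set(layers)
        if unknown:
            raise ValueError(f'unknown targets: {sorted(unknown)}')
        self.names = list(layers)          # layer names, in execution order
        self.body = nn.Sequential(layers)  # names become attributes, e.g. self.body.conv1_1
        self.targets = list(targets)
        # position of the last layer we need; nothing after it is computed
        self.last = max(self.names.index(t) for t in self.targets)

    def forward(self, x):
        found = {}
        for i, (name, module) in enumerate(zip(self.names, self.body)):
            x = module(x)
            if name in self.targets:
                found[name] = x.clone()  # a later in-place op must not overwrite it
            if i == self.last:
                break
        return [found[t] for t in self.targets]
```

For VGG19, `layers` would be filled from `net.features`, `net.avgpool`, an `nn.Flatten()` and `net.classifier`; the outputs come back in the order the targets were requested.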

This seems awfully complicated for what I’m trying to achieve, so I’m probably approaching this the wrong way… help, please? :slight_smile:

This is what I have come up with so far:

```python
import torch
import torch.nn as nn


class VGG19FeatureExtractor(nn.Module):
  def __init__(self, net, targets):
    super(VGG19FeatureExtractor, self).__init__()
    self.features = net.features
    # newer torchvision VGGs have an adaptive avgpool between features and classifier
    self.avgpool = getattr(net, 'avgpool', nn.Identity())
    self.classifier = net.classifier
    if isinstance(targets, int):
      self.targets = {targets}  # TODO: make this nice strings
      self.is_list = False
    else:
      self.targets = set(targets)  # TODO: make this nice strings
      self.is_list = True
    max_index = len(self.features) + len(self.classifier)
    for t in self.targets:
      if not 0 <= t < max_index:
        raise ValueError(f'Invalid target index {t}. Should be within [0, {max_index})')

  def forward(self, x):
    results = list()  # filled in network order, not in the order targets were given
    for i, module in enumerate(self.features):
      x = module(x)
      if i in self.targets:
        # clone: the following ReLU(inplace=True) would otherwise overwrite it
        results.append(x.clone())
        if len(results) == len(self.targets):
          return results if self.is_list else results[0]
    x = torch.flatten(self.avgpool(x), 1)
    i_ = len(self.features)  # classifier modules are numbered starting from this
    for i, module in enumerate(self.classifier):
      x = module(x)
      if i + i_ in self.targets:
        results.append(x.clone())
        if len(results) == len(self.targets):
          return results if self.is_list else results[0]
    raise RuntimeError('should have returned before here. what happened?!')
```

This class takes a vgg19 network as input, along with either a single target index or a collection of target indices (integers).
It computes the requested output(s) and returns a single tensor for a single target, or a list for a collection of targets.
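
One subtlety when collecting intermediate outputs from torchvision's VGG: its ReLUs are created with `inplace=True`, so a tensor stored right after a conv or linear layer gets overwritten by the next ReLU unless it is copied. A small self-contained demonstration, using an identity `nn.Linear` as a stand-in for the conv layer (an assumption made purely for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a "conv -> ReLU(inplace=True)" pair from VGG.
layer = nn.Linear(2, 2)
with torch.no_grad():
    layer.weight.copy_(torch.eye(2))  # identity weights: output == input
    layer.bias.zero_()
relu = nn.ReLU(inplace=True)


def stored_output(x, copy):
    """Run layer -> relu and return what was 'stored' right after `layer`."""
    with torch.no_grad():
        h = layer(x)
        stored = h.clone() if copy else h  # without clone, `stored` aliases `h`
        relu(h)  # overwrites h in place
    return stored


x = torch.tensor([[-1.0, 2.0]])
print(stored_output(x, copy=False))  # negative entry already clamped to 0 by the ReLU
print(stored_output(x, copy=True))   # keeps the pre-activation value -1
```

Forward hooks have the same issue, since a hook sees the layer's output before the next in-place ReLU runs.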

While this seems to do what I want, I have several issues with it:

  • it is incredibly verbose for what it does.
  • it uses integers as targets instead of names such as conv1_2 or pool5.
  • it relies on the features and classifier attributes of the VGG networks (it won’t work with, e.g., ResNet).
  • the double for-loop makes my heart hurt. A simple Flatten module might mitigate this somewhat, though.
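
For the second and last points, one option is to flatten the VGG into a single named sequence. This is a sketch, assuming torchvision's VGG layout (`features` made of Conv2d/ReLU/MaxPool2d, an optional `avgpool`, and `classifier` made of Linear/ReLU/Dropout); `vgg_named_layers` is a hypothetical helper, and the TF-slim-like names (`conv1_1`, `pool5`, `fc6`, ...) are a naming convention chosen for illustration, not something PyTorch provides. An `nn.Flatten` between the two halves replaces the `x.view(...)` call, so a single loop can walk the whole network.

```python
from collections import OrderedDict

import torch.nn as nn


def vgg_named_layers(net):
    """Return an OrderedDict mapping TF-slim-like names to the modules of a VGG."""
    layers = OrderedDict()
    block, conv = 1, 0
    for m in net.features:
        if isinstance(m, nn.Conv2d):
            conv += 1
            layers[f'conv{block}_{conv}'] = m
        elif isinstance(m, nn.ReLU):
            layers[f'relu{block}_{conv}'] = m
        elif isinstance(m, nn.MaxPool2d):
            layers[f'pool{block}'] = m
            block, conv = block + 1, 0
        else:
            raise TypeError(f'unexpected layer in features: {m}')
    if hasattr(net, 'avgpool'):  # present in newer torchvision VGGs
        layers['avgpool'] = net.avgpool
    layers['flatten'] = nn.Flatten()  # replaces x.view(x.shape[0], -1)
    fc = block - 1  # VGG convention: fc6 follows pool5
    for m in net.classifier:
        if isinstance(m, nn.Linear):
            fc += 1
            layers[f'fc{fc}'] = m
        elif isinstance(m, nn.ReLU):
            layers[f'relu{fc}'] = m
        elif isinstance(m, nn.Dropout):
            layers[f'drop{fc}'] = m
        else:
            raise TypeError(f'unexpected layer in classifier: {m}')
    return layers
```

For torchvision's `vgg19`, this should yield `conv1_1` through `conv5_4`, `pool1` through `pool5`, and `fc6` to `fc8`; `nn.Sequential(vgg_named_layers(net))` then exposes them as attributes such as `seq.pool5`. Batch-norm variants such as `vgg19_bn` would need an extra `BatchNorm2d` branch.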


And yes, I put that f-string in there on principle. Change it if you want to use this with Python 2 :stuck_out_tongue:

I finally ended up doing this somewhat properly:

Hi, I made a Python package to get the intermediate outputs of a model’s submodules; check it out: PyPI link