Extracting Feature Vectors Using Pre-Built Models

I’m pretty new to using PyTorch and trying to learn about extracting feature vectors. I have been able to piece together some code that will extract a feature vector from a single image. See below.

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image


# Load the pretrained model
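# Load the pretrained model (weights= needs torchvision >= 0.13; older versions use pretrained=True)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)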

# Use the model object to select the desired layer
layer = model._modules.get('avgpool')

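# Set model to evaluation mode
model.eval()

# Preprocessing expected by ImageNet-pretrained torchvision models:
# resize, center-crop to 224x224, convert the PIL image to a [0, 1] tensor, then normalize
transformations = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
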
def get_vector(image):
    # Create a PyTorch tensor with the transformed image
    t_img = transformations(image)
    # Create a vector of zeros that will hold our feature vector
    # For resnet18, the 'avgpool' layer outputs 512 values per image (resnet50/101/152 output 2048)
    my_embedding = torch.zeros(512)

    # Define a function that will copy the output of a layer
    def copy_data(m, i, o):
        my_embedding.copy_(o.flatten())  # flatten the (1, 512, 1, 1) output to (512,)

    # Attach that function to our selected layer
    h = layer.register_forward_hook(copy_data)
    # Run the model on our transformed image
    with torch.no_grad():  # no gradients are needed for feature extraction
        model(t_img.unsqueeze(0))  # unsqueeze adds a batch dimension: (C, H, W) -> (1, C, H, W)
    # Detach our copy function from the layer
    h.remove()
    # Return the feature vector
    return my_embedding

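# Load an image as RGB ('example.jpg' is a placeholder path)
img = Image.open('example.jpg').convert('RGB')
pic_vector = get_vector(img)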

What I’m curious about is whether there is documentation somewhere that describes the various layers we can extract these vectors from. For example, if I wanted to switch to the ResNet-152 model, how could I find out which layers are available to extract from, and what the syntax is for each one? I’ve been looking for more documentation on the “_modules.get” piece, but I can’t seem to find what I’m looking for.


The approach of using forward hooks looks alright.
In general, you could use any intermediate activation as your “feature” tensor.
To see all modules in a model, you could use print(model), which prints every submodule together with its name, or look at the implementation of the model directly (e.g. the ResNet implementation in torchvision).