Inconsistencies in getting intermediate layer outputs

Hi, I’m trying to get the output of an intermediate layer of resnet50 using two different methods: create_feature_extractor and a forward hook, following here.

Here’s the full code, which I only minimally changed from here:

import torch
from torchvision.models.feature_extraction import create_feature_extractor
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model.eval()

# Download an example image from the pytorch website
import urllib
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
try: urllib.URLopener().retrieve(url, filename)
except: urllib.request.urlretrieve(url, filename)

# sample execution (requires torchvision)
from PIL import Image
from torchvision import transforms
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

layer_name = 'layer1.0.bn3' # other layers show the same problem

# Method I
feature_extractor = create_feature_extractor(model, return_nodes=[layer_name])
with torch.no_grad():
    feature1 = feature_extractor(input_batch)

# Method II
feature2 = {}
def get_activation(name):
    def hook(model, input, output):
        feature2[name] = output.detach()
    return hook
with torch.no_grad():
    model_layer = model._modules.get(layer_name.split('.')[0])
    handle = model_layer.register_forward_hook(get_activation(layer_name))
    out = model(input_batch)
    handle.remove()

I compared the results feature1 and feature2 (for example, feature1[layer_name][0,0,0] vs. feature2[layer_name][0,0,0]), and they are very different.

How can I make the result of the forward hook consistent with the feature extractor?

There are a few issues in your code:

  • You are registering the forward hook on a different module: model_layer points to model.layer1, while the feature_extractor returns the output of model.layer1[0].bn3 (see the sketch after this list).
  • Even after fixing that, your forward hook stores a reference to the batchnorm layer's output, which is subsequently manipulated in-place by self.relu and by the in-place addition with identity, as seen here. Cloning the output fixes the issue.
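As a quick sanity check (a minimal sketch, assuming a PyTorch version that provides nn.Module.get_submodule, i.e. 1.9 or later), you can print which module each lookup actually resolves to:

# The question's lookup only keeps the first part of the dotted name,
# so it returns the entire first residual stage rather than the target layer.
wrong_module = model._modules.get('layer1')          # nn.Sequential (the whole stage)
right_module = model.get_submodule('layer1.0.bn3')   # nn.BatchNorm2d (the intended layer)
print(type(wrong_module).__name__, type(right_module).__name__)
# Sequential BatchNorm2d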

This should work:

layer_name = 'layer1.0.bn3' # other layers show the same problem

# Method I
feature_extractor = create_feature_extractor(model, return_nodes=[layer_name])
with torch.no_grad():
    feature1 = feature_extractor(input_batch)

# Method II
feature2 = {}
def get_activation(name):
    def hook(model, input, output):
        feature2[name] = output.detach().clone()
    return hook

with torch.no_grad():
    model_layer = model.layer1[0].bn3
    handle = model_layer.register_forward_hook(get_activation(layer_name))
    out = model(input_batch)
    handle.remove()

print((feature1["layer1.0.bn3"] - feature2["layer1.0.bn3"]).abs().max())
# tensor(0., device='cuda:0')
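As a side note (a sketch using torchvision's feature_extraction utilities, not part of the original answer), get_graph_node_names lists the node names that create_feature_extractor accepts, and nn.Module.get_submodule resolves a dotted name such as 'layer1.0.bn3' without hard-coding the attribute path:

from torchvision.models.feature_extraction import get_graph_node_names

# Node names are returned separately for the train and eval graphs.
train_nodes, eval_nodes = get_graph_node_names(model)
print([n for n in eval_nodes if n.startswith('layer1.0')])

# Equivalent to model.layer1[0].bn3, usable with register_forward_hook
model_layer = model.get_submodule(layer_name)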

Thanks for the clarification. This helps a lot!