Extract features from a layer of a submodule of a model

My model looks like this:

from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F
import torchvision
    
class hybrid_cnn(nn.Module):
    def __init__(self,**kwargs):
        super(hybrid_cnn,self).__init__()
        resnet = torchvision.models.resnet50(pretrained=True)
        self.base = nn.Sequential(*list(resnet.children())[:-2])

    def forward(self,x):
        x = self.base(x)
        y = self.conv2(self.conv1(x))
        clf_outputs = {}
        num_fcs = 2
        

        return clf_outputs,x,y

I want a tensor containing the output coming from ResNet -> conv1 -> conv2 (classifier layer removed) from both augmented1 and augmented2.
I have tried register_forward_hook, but the removable hook object is not callable. Is there an efficient way to extract features from a submodule (the output of the conv2 layer from both augmented1 and augmented2)?

The forward_hook should work:

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

model.fc0.conv2.register_forward_hook(get_activation('fc0.conv2'))
model.fc1.conv2.register_forward_hook(get_activation('fc1.conv2'))

output = model(x)
print(activation['fc0.conv2'])
print(activation['fc1.conv2'])

I’m not sure what y = self.conv2(self.conv1(x)) in hybrid_cnn is, so I just removed it.
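In case a self-contained example helps, here is a minimal sketch of the same hook pattern on a toy model. The TwoBranch module below is just a stand-in for hybrid_cnn's two augmented branches, not your actual architecture:

import torch
from torch import nn

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

class TwoBranch(nn.Module):
    def __init__(self):
        super(TwoBranch, self).__init__()
        self.base = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # two parallel branches, standing in for the augmented modules
        self.fc0 = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())
        self.fc1 = nn.Sequential(nn.Conv2d(8, 16, 3, padding=1), nn.ReLU())

    def forward(self, x):
        x = self.base(x)
        return self.fc0(x), self.fc1(x)

model = TwoBranch()
# the hooks fire during the forward pass and store each branch's output
model.fc0.register_forward_hook(get_activation('fc0'))
model.fc1.register_forward_hook(get_activation('fc1'))

out0, out1 = model(torch.randn(1, 3, 32, 32))
print(activation['fc0'].shape)  # torch.Size([1, 16, 32, 32])
print(activation['fc1'].shape)  # torch.Size([1, 16, 32, 32])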


Thanks for answering. I am not sure what’s happening here; I followed your steps:

model = hybrid_convnet2.hybrid_cnn().to(device)
t = torch.randn([64,3,256,128]).cuda()

I am able to get the output of model(t) at this stage, but after adding forward hooks

model.fc0.conv2.register_forward_hook(get_activation('fc0.conv2'))
model.fc1.conv2.register_forward_hook(get_activation('fc1.conv2'))

and then calling model(t) gives a "NoneType object is not callable" error.

Do you get the same error without registering the hooks?

No, it prints out the correct output before registering the hooks.

What does print(model.fc0.conv2) return?
Is the error thrown in the forward pass after registering the hooks?

Sequential(
  (0): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(2048, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True)
  (2): LeakyReLU(negative_slope=0.1, inplace)
)

I was doing it your way.
Do I have to add the hooks in the forward pass of hybrid_cnn?

The hooks should already be registered on hybrid_cnn's augmented modules.
Could you post the complete error message?

Here it is:

[screenshot of the error after registering the hooks]

Before registering hooks:

[screenshot of the working output]

I’m not sure what’s going on.
Could you post your code as a gist so that I can run exactly the same code?

Here it is:

Thanks for the code.
Your return statement is in the wrong place: it should belong to get_activation, returning hook, not sit inside hook itself.

The right way:

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

Your current implementation:

def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
        return hook
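That misplaced return is also where the earlier "NoneType object is not callable" error comes from: with the return trapped inside hook, get_activation returns None, so None is registered as the hook and the failure only shows up when the forward pass tries to call it. A quick illustrative check:

activation = {}
def get_activation_buggy(name):
    def hook(model, input, output):
        activation[name] = output.detach()
        return hook   # this return belongs to hook, so the outer function returns None

print(get_activation_buggy('fc0.conv2'))   # None
# registering this None and then calling model(t) raises
# TypeError: 'NoneType' object is not callable during the forward pass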

Oh my bad, didn’t realize it.

The dictionary remains empty after calling get_activation().

Your code with the fix works for me:

output = model(t)
print(activation['fc0.conv2'].shape)
> torch.Size([1, 2048, 8, 4])
print(activation['fc1.conv2'].shape)
> torch.Size([1, 2048, 8, 4])

Could you check it again?
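As a side note, since the original question mentioned the removable hook object: register_forward_hook returns a handle, and that handle is not meant to be called; it only exists to remove the hook once you no longer need it. Roughly (assuming the fixed get_activation from above):

handle0 = model.fc0.conv2.register_forward_hook(get_activation('fc0.conv2'))
handle1 = model.fc1.conv2.register_forward_hook(get_activation('fc1.conv2'))

output = model(t)   # the hooks fire here and fill the activation dict

handle0.remove()    # detach the hooks once the features have been captured
handle1.remove()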


I had another doubt. I wanted to add the Frobenius norm of covariances to my cross entropy loss.
I have tried this:

class f_norm(nn.Module):
    def __init__(self,epsilon,beta):
        super(f_norm,self).__init__()
        self.epsilon = epsilon
        self.beta = beta
    def forward(self, cs,ct):
        epsilon=0.1
        beta = 0.2
        x = self.epsilon*( torch.sqrt((torch.abs(cs-ct))**2)-self.beta)
        y = x.view(-1)
        y = torch.sum(y)
        return y/1000

This doesn’t seem to work; it doesn’t return a Tensor.
I am using this function instead:

def f_norm(cs,ct):
    epsilon=0.1
    beta = 0.2
    x = epsilon*( torch.sqrt((torch.abs(cs-ct))**2)-beta)
    y = x.view(-1)
    y = torch.sum(y)
    return y/1000

loss = cross_entropy(inputs, targets) + f_norm(cs, ct)

If I call loss.backward(), would this work fine? I read on the forums that unless it’s an autograd Variable it would not be updated, so I had concerns about this implementation.

As long as you stick with PyTorch functions, you are fine.
What’s wrong with the first implementation?

If the backward pass doesn’t work, PyTorch will throw an error.
So if you don’t get an error, everything should be fine.
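For example, a quick sanity check could look like this (a sketch with made-up shapes; cs and ct are just random tensors standing in for your covariance matrices, and f_norm is the function you posted above):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
targets = torch.randint(0, 10, (4,))
cs = torch.randn(5, 5, requires_grad=True)   # stand-in for the source covariance
ct = torch.randn(5, 5)                       # stand-in for the target covariance

loss = F.cross_entropy(logits, targets) + f_norm(cs, ct)
loss.backward()

print(cs.grad is not None)   # True -> gradients flow through the custom term as well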

Okay, got it. The first implementation (f_norm class) doesn’t return a Tensor.

I’m not sure why your code outputs strange things, but it is returning a tensor. :wink:
I’ve just removed the unnecessary values (epsilon and beta are defined twice, once in __init__ and again in forward) and it’s working.
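For reference, the class version without the duplicated values would look something like this (just the cleanup described above, keeping your original math):

class f_norm(nn.Module):
    def __init__(self, epsilon, beta):
        super(f_norm, self).__init__()
        self.epsilon = epsilon
        self.beta = beta

    def forward(self, cs, ct):
        # use the values stored in __init__ instead of re-defining them here
        x = self.epsilon * (torch.sqrt((torch.abs(cs - ct))**2) - self.beta)
        return torch.sum(x.view(-1)) / 1000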