How can I extract intermediate layer output from loaded CNN model?

I think the best approach would be to script the model via torch.jit.script and export this structure, as this graph would contain the execution workflow of the model, which would be missing if e.g. you only export the initialized modules.
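
As a rough sketch (assuming model is an nn.Module instance; the file name is just an example):

import torch

scripted = torch.jit.script(model)    # serializes the forward control flow, not only the submodules
scripted.save('model_scripted.pt')    # graph + parameters in one artifact
loaded = torch.jit.load('model_scripted.pt')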

Hi @ptrblck, I’d like to know how to extract the output of an intermediate layer when the model has multiple inputs. For example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.cl1 = nn.Linear(25, 60)
        self.cl2 = nn.Linear(40, 16)
        self.fc1 = nn.Linear(76, 120)

    def forward(self, x1, x2):
        out1 = F.relu(self.cl1(x1))
        out2 = F.relu(self.cl2(x2))
        x = torch.cat((out1, out2), dim=1)
        x = F.log_softmax(self.fc1(x), dim=1)
        return x

And I want to get out1. Could you please help me? Thanks a lot.

You could replace the F.relu operations with their nn.ReLU modules and register a forward hook to the first nn.ReLU module as described before.
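
For illustration, a minimal sketch of that approach applied to the model above (the module names act1/act2 and the hook wiring are my own, not code from this thread):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cl1 = nn.Linear(25, 60)
        self.act1 = nn.ReLU()  # module instead of the functional F.relu call
        self.cl2 = nn.Linear(40, 16)
        self.act2 = nn.ReLU()
        self.fc1 = nn.Linear(76, 120)

    def forward(self, x1, x2):
        out1 = self.act1(self.cl1(x1))  # out1 is now the output of a module
        out2 = self.act2(self.cl2(x2))
        x = torch.cat((out1, out2), dim=1)
        return F.log_softmax(self.fc1(x), dim=1)

activation = {}
model = MyModel()
model.act1.register_forward_hook(
    lambda module, inp, out: activation.update({'out1': out.detach()}))
output = model(torch.rand(4, 25), torch.rand(4, 40))
print(activation['out1'].shape)  # torch.Size([4, 60])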


Hello,

I have created a “recorder” for such goals as part of my torchknickknacks package.

It is very simple to record from multiple layers of PyTorch models, including CNNs.

An example to record output from all conv layers of VGG16:


import torch
from torchknickknacks import modelutils

model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16', pretrained=True)

# Indices of the conv layers in model.features
layer_nr = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28]

# Get the layers from the model
layers = [list(model.features.named_children())[lnr][1] for lnr in layer_nr]

# Assign a Recorder (from modelutils in torchknickknacks) to each layer
recorders = [modelutils.Recorder(layer, record_output=True, backward=False) for layer in layers]

# Pass some data through the model
X = torch.rand(32, 3, 224, 224)
out = model(X)

# Get the output of each recorder (the output of each layer)
rec = [r.recording.detach().clone() for r in recorders]

The Recorder class is based on Understanding Pytorch hooks | Kaggle

Hi @ptrblck, I tried to follow this method as suggested for my network.
I want to extract the features of all 4 layers in a single pass. I am unsure whether they are overwritten, since the layer name is the same in the SSL branch. Could you please confirm whether my method is correct and, if not, suggest a better one?

Thanks:

def forward(self, x1, x2=None, type='SL'):
    # x1, x2: two differently augmented views of the input (x2 is only used in the SSL branch)
    if type == 'SL':
        h0 = self.conv1(x1)  # image with transforms1
        h0 = F.relu(self.bn1(h0))
        h0 = self.maxpool(h0)
        h1 = self.layer1(h0)
        ### 1. extract and save features here ###
        h2 = self.layer2(h1)
        h3 = self.layer3(h2)
        h4 = self.layer4(h3)

        m1_ = self.f_conv1(x1)  # image with transforms1
        m1 = m1_ * h1
        ### 2. extract and save features here ###
        m2_ = self.f_conv2(m1)
        m2 = m2_ * h2
        m3_ = self.f_conv3(m2)
        m3 = m3_ * h3
        m4_ = self.f_conv4(m3)
        m4 = m4_ * h4
        out = self.avgpool(m4)
        out = out.view(out.size(0), -1)
        y = self.linear(out)
        return y

    if type == 'SSL':
        h0 = self.conv1(x1)  # image with transforms2
        h0 = F.relu(self.bn1(h0))
        h0 = self.maxpool(h0)
        h1 = self.layer1(h0)
        ### 3. extract and save features here ###
        h2 = self.layer2(h1)
        h3 = self.layer3(h2)
        h4 = self.layer4(h3)
        feat = self.avgpool(h4)

        h01 = self.conv1(x2)  # image with transforms3
        h01 = F.relu(self.bn1(h01))
        h01 = self.maxpool(h01)
        h11 = self.layer1(h01)
        ### 4. extract and save features here ###
        h21 = self.layer2(h11)
        h31 = self.layer3(h21)
        h41 = self.layer4(h31)
        feat1 = self.avgpool(h41)
        return feat, feat1

I have registered hooks and extracted the activations as follows:

activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = Net(device, 10)
model = model.to(device)
# the hooks stay registered, so they also fire during the SSL passes below
model.layer1.register_forward_hook(get_activation('layer1'))
model.f_conv1.register_forward_hook(get_activation('f_conv1'))
x = torch.randn(1, 3, 128, 128).to(device)
output = model(x, type='SL')
print(activation['layer1'])
print(activation['f_conv1'])

# two differently augmented views for the SSL branch
x1 = torch.randn(1, 3, 128, 128).to(device)
x2 = torch.randn(1, 3, 128, 128).to(device)
output = model(x1, x2, type='SSL')
# layer1 runs twice in the SSL branch, so the same key is overwritten and
# only the activation of the second pass (h11) is kept
print(activation['layer1'])

Yes, reusing the same key in the dict would replace the old activations.
A better approach would be to use different names in case you want to keep all intermediate activations.

Do you know the module torchvision.models.feature_extraction?

That is what I usually use
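
For reference, a minimal sketch of how it can be used (the resnet18 model and the node names are only illustrative):

import torch
from torchvision.models import resnet18
from torchvision.models.feature_extraction import create_feature_extractor

model = resnet18()
# map internal node names to the keys of the returned dict
extractor = create_feature_extractor(model, return_nodes={'layer1': 'feat1', 'layer4': 'feat4'})
feats = extractor(torch.rand(1, 3, 224, 224))
print(feats['feat1'].shape)  # torch.Size([1, 64, 56, 56])
print(feats['feat4'].shape)  # torch.Size([1, 512, 7, 7])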

Thanks @ptrblck for the confirmation.

In my case the key (layer name) refers to the same layer from which I am trying to extract the representations more than once, so how do I change the key name? If I want to register layer1 twice, would it work if I change the key passed to get_activation('key name')?

 model.layer1.register_forward_hook(get_activation('layer-h11'))
 model.layer1.register_forward_hook(get_activation('layer-h41'))

What is the difference between returning the intermediate layers from the forward function of the example network and using hooks to save and access them later?

def forward(self, x, type='SL'):
    if type == 'SL':
        h0 = self.conv1(x)
        h0 = F.relu(self.bn1(h0))
        h1 = self.layer1(h0)
        # ... more code
        m1_ = self.f_conv1(x)
        m1 = m1_ * h1
        out = self.avgpool(m1)
        out = out.view(out.size(0), -1)
        y = self.linear(out)

        return y, h1, m1

Many Thanks

Thanks @Miguel_Campos, I have never used it; I will look into it.


Yes, you can change the name in get_activation since it’s used as the key for the activation dict.
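
As an illustration, a minimal sketch of keeping the activations of two passes under different keys (the key names and the handle-based re-registration are my own suggestion, not code from this thread):

# first pass: store layer1's output under one key
handle = model.layer1.register_forward_hook(get_activation('layer1-pass1'))
out1 = model(x, type='SL')
handle.remove()  # detach the hook before registering it again with a new key

# second pass: store layer1's output under a different key
handle = model.layer1.register_forward_hook(get_activation('layer1-pass2'))
out2 = model(x, type='SL')
handle.remove()

print(activation.keys())  # dict_keys(['layer1-pass1', 'layer1-pass2'])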

You can do either and pick the approach which fits your use case better.

Thank you so much @ptrblck

Hi @ptrblck, just a quick follow-up question to confirm: can I return multiple outputs without using detach()? Would that have any effect on backpropagation or anything else?

def forward(self, x):
    h0 = self.conv1(x)
    h0 = F.relu(self.bn1(h0))
    h1 = self.layer1(h0)
    out = self.avgpool(h1)
    out = out.view(out.size(0), -1)
    y = self.linear(out)
    return h1, out, y

Many Thanks !!

Yes, you can return multiple tensors.
As long as you don’t call .backward on these additional tensors or any tensors calculated from them, they won’t influence the original gradient calculation.
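
For example, in a sketch like the following (criterion and target are illustrative placeholders), the extra returned tensors stay part of the graph but do not change the gradients as long as only the loss derived from y is used for backward:

h1, out, y = model(x)
loss = criterion(y, target)  # e.g. nn.CrossEntropyLoss()
loss.backward()              # backpropagates through the graph that produced y
# h1 and out are not detached, but they add nothing to the gradient
# computation unless you call .backward on them or on tensors derived from them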

What is the PyTorch equivalent of model.layers[index].output in TensorFlow?

Can we extract each neuron as a variable in any layer of an NN model and apply optimization constraints to each neuron?

PyTorch layers do not store an .output attribute, and you can directly get the output tensor via:

output = layer(input)

re: Can we extract each neuron as a variable in any layer of an NN model and apply optimization constraints to each neuron?

But I want to add these constraints as a layer inside the NN model, so I need some function in PyTorch that gives me access to each neuron as a variable. Is it possible?

I assume you are referring to the weights by “neurons”. If so, then yes, you can directly access them via model.layer.weight and model.layer.bias and could add them to your constraints operations.
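
For instance, a minimal sketch (the layer sizes and the L2 penalty are only illustrative assumptions):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5))
w = model[0].weight  # nn.Parameter of shape [5, 10]
b = model[0].bias    # nn.Parameter of shape [5]

# example: add an L2 penalty on this layer's weights to the loss
out = model(torch.rand(4, 10))
loss = out.sum() + 1e-4 * w.pow(2).sum()
loss.backward()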

Is there an example of using jit.script for this sole purpose (retrieving features from intermediate layers without having to modify the original model)?

Scripting a model via torch.jit.script will try to optimize it by analyzing and processing the computation graph (e.g., on the GPU, operations could be fused as seen here). Scripting is not used to return additional (intermediate) outputs from the model.