Create forward function using intermediate layers from pretrained model

I am trying to implement visual attention on top of an already pretrained implementation of EfficientNet found here. To use the attention features, I need to get the outputs of specific layers of the EfficientNet model and then use them in my last Linear layer.
However, I do not know how to get these intermediate layers and feed them to my attention blocks.

class AttentionNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()  # must run before any submodule is assigned
        self.model = EfficientNet.from_pretrained('efficientnet-b2')

        # These are my attention layers
        self.projector = ProjectorBlock(256, 512)
        self.attn1 = LinearAttentionBlock(in_features=512, normalize_attn=True)
        self.attn2 = LinearAttentionBlock(in_features=512, normalize_attn=True)
        self.attn3 = LinearAttentionBlock(in_features=512, normalize_attn=True)

        # This is my classification layer
        self.classification = nn.Linear(in_features=512*3, out_features=num_classes, bias=True)

        # This is the layer I want to access
        self.inter_1 = self.model._blocks[0]._bn2

    def forward(self, x):
        x = self.model(x)
        # l1 = Get features from self.inter_1  (this is the part I am missing)

        c1, g1 = self.attn1(self.projector(l1), x)
        x = self.classification(g1)  # batch_size x num_classes
        return [x, c1]

The easiest way is probably to subclass the EfficientNet class and override forward so that it incorporates the logic of extract_features and returns the intermediate activations you need alongside the final output.

Best regards
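
The subclassing suggestion above could look roughly like the sketch below. To keep it self-contained it uses a tiny made-up backbone (TinyBackbone, a hypothetical stand-in, not EfficientNet); the layer names stem, block1, block2 are assumptions for illustration only. With the real model you would repeat the body of its own forward/extract_features and keep a reference to the activation you care about.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a pretrained network (not EfficientNet)."""

    def __init__(self, num_classes=4):
        super().__init__()
        self.stem = nn.Sequential(nn.Conv2d(3, 8, 3, stride=2, padding=1), nn.ReLU())
        self.block1 = nn.Sequential(nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU())
        self.block2 = nn.Sequential(nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.block2(self.block1(self.stem(x)))
        return self.fc(self.pool(x).flatten(1))


class BackboneWithFeatures(TinyBackbone):
    """Same parameters as the parent; forward also returns an intermediate map."""

    def forward(self, x):
        x = self.stem(x)
        l1 = self.block1(x)  # intermediate activation to feed to attention later
        x = self.block2(l1)
        logits = self.fc(self.pool(x).flatten(1))
        return logits, l1


model = BackboneWithFeatures(num_classes=4).eval()
# The subclass adds no parameters, so weights of the parent load unchanged.
model.load_state_dict(TinyBackbone(num_classes=4).state_dict())

with torch.no_grad():
    logits, l1 = model(torch.randn(2, 3, 64, 64))
print(logits.shape, l1.shape)
```

Because the subclass only changes forward, a checkpoint saved from the original class still loads, which is the main appeal of this approach over rebuilding the network by hand.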



For a baseline, post-hoc implementation of this idea for torchvision.models, look here.

Or you could use my EfficientNet impl; it supports feature extraction out of the box :)

I’m currently refining the feature, so the most up-to-date version is on my ‘features’ branch.

It was intended for grabbing features appropriate for feature pyramid network style inputs (object detection, segmentation, keypoints, etc.), as you can see below, so it lets you grab the deepest set of features at each stride level at preset points. That is not quite what you want, since you seem to be keeping the last FC layer and have a more specific point in mind…

>>> import timm
>>> m = timm.create_model('efficientnet_b2', pretrained=True, features_only=True, out_indices=(1,4), feature_location='expansion')
>>> m.feature_info.module_name()
['blocks.1.2.conv_pwl', 'blocks.6.1.conv_pwl']
>>> m.feature_info.channels()
[144, 2112]

For your case, though, you could use the utilities I’ve put together to help with feature extraction. See the example below, which skips your projection/attention blocks for clarity:

import torch
import torch.nn as nn
import timm
from timm.models.features import FeatureHooks

class AttentionNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()  # must run before any submodule is assigned
        self.hook_loc = 'blocks.0.0.bn2'  # name of the module whose output we capture
        self.model = timm.create_model('efficientnet_b2', pretrained=True, num_classes=num_classes)
        self.hooks = FeatureHooks([dict(module=self.hook_loc)], self.model.named_modules())
        # ...

    def forward(self, x):
        x = self.model(x)
        # Output recorded by the hook during the forward pass above
        l1 = self.hooks.get_output(x.device)[self.hook_loc]

        # ....
        return [x, l1]

>>> aa = AttentionNet(1000)
>>> o = aa(torch.randn(2, 3, 224, 224))
>>> for x in o:
...   print(x.shape)
torch.Size([2, 1000])
torch.Size([2, 16, 112, 112])
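
If you would rather not depend on a library-specific helper, the same idea can be sketched with PyTorch's built-in register_forward_hook. This is only an illustration under assumptions: the small nn.Sequential below is a hypothetical stand-in for the pretrained model, and the names HookedNet, _save and the "bn" key are made up for this example.

```python
import torch
import torch.nn as nn


class HookedNet(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        # Hypothetical toy backbone standing in for the pretrained model.
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(16, num_classes),
        )
        self._features = {}
        # Record the BatchNorm output every time the backbone runs.
        self.model[1].register_forward_hook(self._save("bn"))

    def _save(self, name):
        def hook(module, inputs, output):
            self._features[name] = output
        return hook

    def forward(self, x):
        x = self.model(x)
        # pop() so a stale activation is never reused on a later call
        l1 = self._features.pop("bn")
        return [x, l1]


net = HookedNet(num_classes=4).eval()
with torch.no_grad():
    out, feat = net(torch.randn(2, 3, 32, 32))
print(out.shape, feat.shape)
```

The hook fires as a side effect of the normal forward pass, so the backbone's own forward does not need to be modified; the trade-off is that the captured tensor lives in module state between the hook firing and the read.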