create_feature_extractor() does not find specified layers

Hello, I want to create a CNN explainability class using the create_feature_extractor() functionality. The problem I am facing is that I cannot reach certain layers in the models, for example (features.7.0 is the last convolutional layer before the average pool):

from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names


model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
model.eval()
return_nodes = ['features.7.0', 'classifier.1']
feature_extractor = create_feature_extractor(model, return_nodes=return_nodes)

This piece of code produces an error:

ValueError: node: 'features.7.0' is not present in model. Hint: use `get_graph_node_names` to make sure the `return_nodes` you specified are present. It may even be that you need to specify `train_return_nodes` and `eval_return_nodes` separately.

As the error suggests, I ran get_graph_node_names(model) and could not find this layer, even though I can access model.features[7][0], which gives me the last convolution.
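
For example, filtering the node names around that block only shows the block-level entry (a quick sanity check; output abbreviated):

train_nodes, eval_nodes = get_graph_node_names(model)
# only 'features.7' shows up; its children ('features.7.0', ...) are not listed
print([name for name in eval_nodes if name.startswith('features.7')])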

If I pass the return nodes as return_nodes = ['features.7', 'classifier.1'] (i.e. without the .0 part), the feature extractor works, but the problem is that features.7 is a Conv2dNormActivation, which consists of one convolution, one batch norm and one activation function.

Is there a possible solution to this problem so that I can get the output of this convolution operation from the feature extractor functionality?

Hi folks.

I have a similar issue with layers whose names end with _<number>. There is a procedure that automatically renames these layers. I am guessing it assumes the suffix was introduced during the traversal of the network when the same module was encountered multiple times.
Below is a small example that illustrates this:

import torch
import torch.nn as nn
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.Lin_1 = nn.Linear(in_features=64, out_features=32, bias=True)
        self.ACT_1 = nn.ReLU()  # out_features=32
        self.Lin_2 = nn.Linear(in_features=32, out_features=16, bias=True)
        self.ACT_2 = nn.ReLU()
        self.Lin_3 = nn.Linear(in_features=16, out_features=8, bias=True)

    def forward(self, X):
        X = self.Lin_1(X)
        X = self.ACT_1(X)
        X = self.Lin_2(X)
        X = self.ACT_2(X)
        X = self.Lin_3(X)
        return X
    
if __name__ == '__main__':
    m = M()
    print(f'output shape = {m(torch.rand((64,))).shape}') # checking that the model is correct
    print(f'get_graph_node_names(m) = {get_graph_node_names(m)}')
    print(f'should return false: {"Lin" in get_graph_node_names(m)[0]}')
    print(f'should return true: {"Lin_3" in get_graph_node_names(m)[0]}')
    ce = create_feature_extractor(m, ['ACT_1'])
    output = ce(torch.rand((64,)))['ACT_1']
    print(f'out features should be 32: {output.shape}')

The output is:

output shape = torch.Size([8])
get_graph_node_names(m) = (['x', 'Lin', 'ACT', 'Lin_1', 'ACT_1', 'Lin_2'], ['x', 'Lin', 'ACT', 'Lin_1', 'ACT_1', 'Lin_2'])
should return false: True
should return true: False
out features should be 32: torch.Size([16])

My current solution is to rename the layers so that they have a different suffix; for instance, I simply add an x at the end of their names (it’s annoying because I did not build the original NN). Would there be a way to tell get_graph_node_names not to rename these layers?
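
For reference, this is roughly what that workaround looks like on the toy model above, continuing from the imports in the previous snippet (the Mx name and the trailing x are just illustrative):

class Mx(nn.Module):
    def __init__(self):
        super().__init__()
        # same layers as M, but the attribute names no longer end in _<number>
        self.Lin_1x = nn.Linear(in_features=64, out_features=32, bias=True)
        self.ACT_1x = nn.ReLU()
        self.Lin_2x = nn.Linear(in_features=32, out_features=16, bias=True)
        self.ACT_2x = nn.ReLU()
        self.Lin_3x = nn.Linear(in_features=16, out_features=8, bias=True)

    def forward(self, X):
        return self.Lin_3x(self.ACT_2x(self.Lin_2x(self.ACT_1x(self.Lin_1x(X)))))

mx = Mx()
print(get_graph_node_names(mx)[0])  # the _<n> parts are preserved, e.g. 'Lin_1x', ..., 'Lin_3x'
ce = create_feature_extractor(mx, ['ACT_1x'])
print(ce(torch.rand((64,)))['ACT_1x'].shape)  # torch.Size([32]), as intended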


Hey there Alban,
the “_{0-9}+$” pattern is indeed removed automatically here (to compensate for a potential suffix added by torch.fx), and I don’t see any existing way to prevent this removal other than your simple trick of renaming the layers.
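
Just to make the effect of that pattern concrete (a sketch of the described behaviour, not the actual torchvision code):

import re

pattern = re.compile(r"_[0-9]+$")
print(pattern.sub("", "Lin_1"))   # 'Lin'    -> collides with the node already named 'Lin'
print(pattern.sub("", "Lin_1x"))  # 'Lin_1x' -> unchanged, which is why the renaming trick works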

As for the OP’s question, the issue comes from the definition of the ConvNormActivation class, which lacks an explicit __call__ method and is therefore treated as a black box when tracing the graph.
With this change:

# torchvision/ops/misc.py -- ConvNormActivation, with an explicit __call__ added at the end
class ConvNormActivation(torch.nn.Sequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]] = 3,
        stride: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...], str]] = None,
        groups: int = 1,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
        activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
        dilation: Union[int, Tuple[int, ...]] = 1,
        inplace: Optional[bool] = True,
        bias: Optional[bool] = None,
        conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
    ) -> None:

        if padding is None:
            if isinstance(kernel_size, int) and isinstance(dilation, int):
                padding = (kernel_size - 1) // 2 * dilation
            else:
                _conv_dim = len(kernel_size) if isinstance(kernel_size, Sequence) else len(dilation)
                kernel_size = _make_ntuple(kernel_size, _conv_dim)
                dilation = _make_ntuple(dilation, _conv_dim)
                padding = tuple((kernel_size[i] - 1) // 2 * dilation[i] for i in range(_conv_dim))
        if bias is None:
            bias = norm_layer is None

        layers = [
            conv_layer(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                dilation=dilation,
                groups=groups,
                bias=bias,
            )
        ]

        if norm_layer is not None:
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            params = {} if inplace is None else {"inplace": inplace}
            layers.append(activation_layer(**params))
        super().__init__(*layers)
        _log_api_usage_once(self)
        self.out_channels = out_channels

        if self.__class__ == ConvNormActivation:
            warnings.warn(
                "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
            )

    # added: an explicit __call__ so that torch.fx traces into the constituent layers
    # instead of treating the whole block as a single leaf node
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

access to "features.7.0" is possible. Hope it helps!
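
With that patch in place, the snippet from the first post should work as intended, e.g. (a sketch, reusing the model defined there):

return_nodes = ['features.7.0', 'classifier.1']
feature_extractor = create_feature_extractor(model, return_nodes=return_nodes)
out = feature_extractor(torch.rand(1, 3, 224, 224))
print(out['features.7.0'].shape)  # output of the last convolution before the avg pool
print(out['classifier.1'].shape)  # classifier logits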


Hello. The way I have solved it is just to use good old hooks :grinning:. I couldn’t find any other workaround for this issue. Anyway, big thanks for the clarification on where the inner issue with the __call__ method was.
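
In case it is useful to anyone else, this is roughly what the hook-based version looks like (a sketch, with the module path taken from the first post):

import torch
from torchvision import models

model = models.efficientnet_v2_s(weights=models.EfficientNet_V2_S_Weights.DEFAULT)
model.eval()

captured = {}

def save_output(name):
    # returns a hook that stores the module's output under the given key
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

# register the hook directly on the last convolution; no graph tracing is involved
handle = model.features[7][0].register_forward_hook(save_output('features.7.0'))

with torch.no_grad():
    logits = model(torch.rand(1, 3, 224, 224))

print(captured['features.7.0'].shape)  # output of the last conv before the avg pool
handle.remove()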