How to access a layer by module name?

I have a ResNet34 model and I want to find all the ReLU layers. I used the named_modules() method to get the layers.

for name, layer in model.named_modules():
    if isinstance(layer, nn.ReLU):
        print(name, layer)

And got output like this:

0.2 ReLU(inplace=True)
0.4.0.relu ReLU(inplace=True)
0.4.1.relu ReLU(inplace=True)
0.4.2.relu ReLU(inplace=True)

Is there any way I can use the name (for example 0.4.0.relu) directly to access a ReLU layer? Or do I have to index the numbers with [], like model[0][4][0].relu?


Hi

>>> import torchvision
>>> model = torchvision.models.resnet34()
>>> model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )

You can access the top-level ReLU (the one applied after conv1 and bn1) with:

model.relu

Also, if you want to access the ReLU layers in layer1, you can use the following code to access the ReLU in basic blocks 0 and 1:

model.layer1[0].relu
model.layer1[1].relu

The numeric parts of a name obtained from named_modules() are indices, which you use with [] (e.g. model.layer1[0]).
The non-numeric parts, such as layer1, are attribute names, so you use them as model.layer1.
This works because nn.Sequential objects support indexing, whereas plain nn.Module objects do not.
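As a minimal sketch of that rule, the hypothetical helper access_by_name below walks a dotted name from named_modules(), using [] for numeric parts (assuming the parent is indexable, like nn.Sequential) and getattr() for everything else. The small Block/TinyNet classes are made-up stand-ins for ResNet so the example runs without torchvision.

```python
from torch import nn


def access_by_name(model: nn.Module, name: str) -> nn.Module:
    """Hypothetical helper: resolve 'layer1.0.relu' as model.layer1[0].relu."""
    module = model
    for part in name.split('.'):
        if part.isdigit():
            # Numeric parts come from nn.Sequential, which supports indexing.
            module = module[int(part)]
        else:
            # Named parts are ordinary attributes.
            module = getattr(module, part)
    return module


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(Block(), Block())


model = TinyNet()
for name, layer in model.named_modules():
    if isinstance(layer, nn.ReLU):
        # Prints "layer1.0.relu True", then "layer1.1.relu True"
        print(name, access_by_name(model, name) is layer)
```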

I hope this solves your question.


I have the same question as the OP, which I don’t think was answered by @surya00060.

If model.named_modules() returns names for all the different layers, there must be a way to access the layers using the name string, no?

The name is much more convenient than something like:

model.layer1[0].relu
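One way to use the name string directly, without writing any traversal code, is to turn the (name, module) pairs yielded by named_modules() into a dict. A minimal sketch on a small stand-in model (an assumption, not ResNet, so it runs without torchvision):

```python
from torch import nn

# Small stand-in network; its names are '0', '1', '2', '2.0', '2.1'
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(inplace=True),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.ReLU(inplace=True)),
)

# Map every dotted name to its module (the root module is stored under '').
modules_by_name = dict(model.named_modules())

relu = modules_by_name['2.1']
print(relu)                  # ReLU(inplace=True)
print(relu is model[2][1])   # True
```

The dict is a snapshot: if you later replace or add submodules, rebuild it.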

Hi @edyuan

I think it is not possible to access every layer of a PyTorch model directly by its name. Looking at the names, a part is an index when the layer was created inside an nn.Sequential, and a module name otherwise.

>>> for name, layer in model.named_modules():
...     if isinstance(layer, torch.nn.Conv2d):
...         print(name, layer)

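The output of this snippet is below (this listing has two blocks per stage, which matches the layout of a ResNet-18; a ResNet-34 has 3, 4, 6 and 3 blocks in layer1 to layer4):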

conv1 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
layer1.0.conv1 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer1.0.conv2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer1.1.conv1 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer1.1.conv2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer2.0.conv1 Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
layer2.0.conv2 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer2.0.downsample.0 Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
layer2.1.conv1 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer2.1.conv2 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer3.0.conv1 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
layer3.0.conv2 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer3.0.downsample.0 Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
layer3.1.conv1 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer3.1.conv2 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer4.0.conv1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
layer4.0.conv2 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer4.0.downsample.0 Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
layer4.1.conv1 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
layer4.1.conv2 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

Each name is a chain of strings and numbers, e.g. layer2.0.downsample.0. Wherever a number appears, you need to index with brackets instead of a dot, so that name becomes model.layer2[0].downsample[0].

I hope this clarifies your doubt.


Hi there, I had a somewhat related problem: I wanted to apply a function to specific modules based on their names (state_dict and named_modules do return names of a sort) rather than on their types (which is the approach I’ve seen everywhere so far).

I ended up with this hack (I call it a hack because I’m pretty new to PyTorch and not yet sure it is 100% safe): just create the names yourself:

for n, m in model.named_modules():
    m.auto_name = n

You can then filter on .auto_name when iterating over, e.g., model.modules().
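A minimal sketch of that idea, using the post's ad-hoc auto_name attribute (not a PyTorch API) to freeze the parameters of every module under a given name. The model and the 'features' prefix are illustrative assumptions:

```python
from collections import OrderedDict

from torch import nn

model = nn.Sequential(OrderedDict([
    ('features', nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())),
    ('head', nn.Linear(8, 2)),
]))

# Tag every module with its dotted name (ad-hoc attribute, as in the post).
for n, m in model.named_modules():
    m.auto_name = n

# Later, filter on the tag while iterating over model.modules().
for m in model.modules():
    # Exact match or 'features.' prefix, so e.g. 'features_extra' is not caught
    if m.auto_name == 'features' or m.auto_name.startswith('features.'):
        # recurse=False: each parameter is handled by the module that owns it
        for p in m.parameters(recurse=False):
            p.requires_grad_(False)

print([name for name, p in model.named_parameters() if p.requires_grad])
# ['head.weight', 'head.bias']
```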

Hope this helps!


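If the layers are direct children of the model (their names contain no dots), you can access them by name, as you described, with getattr():

for name, layer in model.named_modules():
    if isinstance(layer, nn.ReLU):
        print(name, layer)
        # getattr() resolves a single attribute, so only top-level names
        # such as 'relu' work here; 'layer1.0.relu' would raise AttributeError
        if '.' not in name:
            pytorch_layer_obj = getattr(model, name)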
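A small sketch of why the name must be top-level for getattr(): it resolves exactly one attribute, so a dotted name from named_modules() is looked up as a single child name and is not found. The Net class is a made-up example:

```python
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
        self.block = nn.Sequential(nn.Linear(4, 4), nn.ReLU())


model = Net()

# A direct child is found by its name.
print(getattr(model, 'relu'))

# 'block.1' is treated as one attribute name, not as model.block[1].
try:
    getattr(model, 'block.1')
except AttributeError as err:
    print('AttributeError:', err)
```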

TL;DR

>>> from functools import reduce
>>> from torchvision.models import resnet34
>>>
>>> def get_module_by_name(module, access_string):
...     names = access_string.split(sep='.')
...     return reduce(getattr, names, module)
>>>
>>> model = resnet34()
>>> get_module_by_name(model, 'layer1.0.relu')
ReLU(inplace=True)

The most confusing point is that nn.Sequential names its submodules with numbers by default. Numbers are not valid Python identifiers, so they cannot be reached with dot notation: with torchvision’s ResNet, for example, model.layer1.0.relu is a syntax error.

But you can work around it.


Accessing direct child by name

If you control the network implementation, you can give the Sequential’s submodules valid names by passing it an OrderedDict, as shown in the documentation, and then access the layers with dot notation as you would for any other module:

>>> from collections import OrderedDict
>>> from torch import nn
>>> model = nn.Sequential(OrderedDict([
...     ('conv1', nn.Conv2d(1,20,5)),
...     ('relu1', nn.ReLU()),
...     ('conv2', nn.Conv2d(20,64,5)),
...     ('relu2', nn.ReLU())]))
>>> model.relu1
ReLU()

As @Inder pointed out, an alternative to dot notation for accessing attributes is the built-in getattr() function.
Unlike dot notation, it accepts any string, so it also works with Sequential’s default numeric names passed as strings (Sequential does not store its submodules in a list but in a dict with string keys).

To get back to my first example, doing getattr(model.layer1, '0').relu works.
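A quick check of that claim on a toy Sequential (an illustrative assumption). It peeks at _modules, an internal attribute, only to show the string keys:

```python
from torch import nn

seq = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# Submodules are stored under string keys, not list positions.
print(list(seq._modules.keys()))   # ['0', '1']

# seq.1 would be a SyntaxError, but the numeric string works with getattr().
print(getattr(seq, '1') is seq[1])  # True
```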


Accessing nested child by name

Unfortunately, getattr() only resolves one attribute at a time: getattr(foo, 'bar.baz') looks for an attribute literally named 'bar.baz' rather than reaching foo.bar.baz.
The solution is to apply getattr() repeatedly yourself, one name at a time, and that’s easy enough with reduce():

from functools import reduce
from typing import Union

import torch
from torch import nn


def get_module_by_name(module: nn.Module,
                       access_string: str) -> Union[torch.Tensor, nn.Module]:
    """Retrieve a module nested in another by its access string.

    Works even when there is a Sequential in the module.
    """
    names = access_string.split(sep='.')
    return reduce(getattr, names, module)


if __name__ == '__main__':
    from torchvision.models import resnet34
    
    model = resnet34()
    print(get_module_by_name(model, 'layer1.0.relu'))

Output:

ReLU(inplace=True)
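In recent PyTorch versions (1.9 and later), nn.Module also has a built-in get_submodule() that walks a dotted name in essentially the same way and raises an AttributeError naming the missing part. A sketch on a toy model (assumed, so it runs without torchvision):

```python
from functools import reduce

from torch import nn

model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(inplace=True)),
    nn.Flatten(),
)

# Built-in lookup by dotted name; numeric Sequential names are fine.
relu = model.get_submodule('0.1')

# Same object as the reduce(getattr, ...) approach.
assert relu is reduce(getattr, '0.1'.split('.'), model)
print(relu)  # ReLU(inplace=True)
```

Unlike reduce(getattr, ...), get_submodule() only returns modules; for a parameter name such as 'layer1.0.conv1.weight' there is a separate get_parameter().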
import torch.nn as nn
from typing import Optional, Union

def find_layer(model: nn.Module, identifier: Union[str, int]) -> Optional[nn.Module]:
    """
    Find a layer in a PyTorch model either by its name using dot notation for nested layers or by its index.

    Parameters
    ----------
    model : nn.Module
        Model from which to search for the layer.
    identifier : str or int
        Layer name using dot notation for nested layers or layer index to find in the model.

    Returns
    -------
    nn.Module
        The layer found, or None if no such layer exists.

    Raises
    ------
    ValueError
        If the identifier is neither a string nor an integer.
    """
    # Flatten the model into a list of layers if index is provided
    if isinstance(identifier, int):
        layers = []
        def flatten_model(module):
            for child in module.children():
                if len(list(child.children())) == 0:
                    layers.append(child)
                else:
                    flatten_model(child)
        flatten_model(model)
        if 0 <= identifier < len(layers):
            return layers[identifier]
        return None

    elif isinstance(identifier, str):
        # Access by dot-notated name
        parts = identifier.split('.')
        current_module = model
        try:
            for part in parts:
                current_module = getattr(current_module, part)
            return current_module
        except AttributeError:
            return None
    else:
        raise ValueError(f"Identifier must be either an integer or a string, got {type(identifier)}.")

# Example usage:
import torchvision.models as models

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # torchvision >= 0.13; older versions use pretrained=True

# Retrieve a nested layer by dot notation name
layer_by_name = find_layer(model, 'layer1.0.conv1')
if layer_by_name is not None:
    print(f"Retrieved by name: {layer_by_name}")
else:
    print("No such layer found by name.")

# Retrieve layer by index
layer_by_index = find_layer(model, 10)
if layer_by_index is not None:
    print(f"Retrieved by index: {layer_by_index}")
else:
    print("No such layer found by index.")

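With this code you can access layers both by name (a dot-notated string) and by index (an int). The integer index counts leaf modules (those without children) in the order they were registered, which is not necessarily the order in which they run in forward().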
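A small sketch of that caveat, with a hypothetical module that registers its ReLU before its Conv2d but calls the Conv2d first. Listing the leaf modules depth-first, as flatten_model() does for this model, gives registration order, not forward order:

```python
import torch
from torch import nn


class Reordered(nn.Module):
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU()            # registered first
        self.conv = nn.Conv2d(3, 8, 3)  # registered second

    def forward(self, x):
        return self.act(self.conv(x))   # but conv runs first


model = Reordered()

# Leaf modules (no children) in depth-first registration order.
leaves = [m for m in model.modules() if not list(m.children())]
print([type(m).__name__ for m in leaves])    # ['ReLU', 'Conv2d']
print(model(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 8, 6, 6])
```

So an integer index of 0 would return the ReLU here, even though the convolution is the first layer applied to the input.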
