Creating a DAG (computational graph) of NN layers

Hi,

I am trying to create a DAG of all the layers/operations in an NN.

An example:

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

class SimpleResNet(nn.Module):
    def __init__(self):
        super(SimpleResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resblock1 = ResidualBlock(64, 64)
        self.resblock2 = ResidualBlock(64, 64)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

Here I want each layer to be a node of my DAG. For example:

  • The input of node bn1 will be conv1, and its output will be relu.
  • The inputs of resblock1.relu will be maxpool and resblock1.bn2, and its output will be resblock2.conv1 (see the sketch below).
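
Concretely, something like this adjacency mapping is what I’m after (a hand-written sketch of the structure I want, not output from any tool):

# Desired layer-level DAG: node -> list of successor nodes, keyed by the
# qualified module names from named_modules().
desired_dag = {
    "conv1":   ["bn1"],
    "bn1":     ["relu"],
    "relu":    ["maxpool"],
    "maxpool": ["resblock1.conv1", "resblock1.relu"],  # second edge is the skip connection
    "resblock1.conv1": ["resblock1.bn1"],
    # ... and so on, down to "fc"
}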

I have tried using .named_modules() and .named_children(), but haven’t been able to make it work: they enumerate the modules without telling me how they are connected (see the snippet below).
I’m guessing PyTorch already creates such a computational graph internally (the autograd graph). Is there a good way to access it? Any suggestions on how to implement this would be very helpful.
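
named_modules(), for instance, only yields a flat list of (name, module) pairs, with no information about how the modules are wired together:

model = SimpleResNet()

# Qualified names in registration order -- a flat listing, no edges:
for name, module in model.named_modules():
    if name:  # the root module has the empty name; skip it
        print(name, type(module).__name__)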

Thanks!
M

To build a Directed Acyclic Graph (DAG) of the layers/operations in a PyTorch network, you can define a small DAG class to store nodes and edges, and register forward hooks on the modules to populate it. Each hook fires during the forward pass and adds edges from a module’s input tensors to its output tensor, so a single forward pass builds up the graph.

import torch
import torch.nn as nn
from collections import defaultdict

class DAG:
    def __init__(self):
        self.graph = defaultdict(list)

    def add_edge(self, u, v):
        self.graph[u].append(v)

class CustomDAGBuilder:
    def __init__(self, model):
        self.model = model
        self.dag = DAG()

    def register_hooks(self, module):
        def forward_hook(module, input, output):
            # Add an edge from each input tensor to the output tensor.
            for inp in input:
                if isinstance(inp, torch.Tensor):
                    self.dag.add_edge(inp, output)

        # Hook everything except containers and the top-level model itself.
        if not isinstance(module, (nn.Sequential, nn.ModuleList)) and \
           module is not self.model:
            module.register_forward_hook(forward_hook)

    def build_dag(self):
        self.model.apply(self.register_hooks)

    def print_dag(self):
        # Nodes are tensors, so print shapes and ids rather than full contents.
        for tensor, successors in self.dag.graph.items():
            src = f'Tensor{tuple(tensor.shape)}@{id(tensor)}'
            dsts = [f'Tensor{tuple(t.shape)}@{id(t)}' for t in successors]
            print(f'{src} -> {dsts}')


model = SimpleResNet()
dag_builder = CustomDAGBuilder(model)
dag_builder.build_dag()

# One forward pass fires the hooks and populates the graph.
x = torch.randn(1, 3, 32, 32)
_ = model(x)
dag_builder.print_dag()

Thanks for the reply.

This approach with hooks builds a DAG of the input and output tensors in the network.
However, I want a DAG of the layers: each layer (named module) would correspond to a node in the graph, as in the example in my original post. Is it possible to achieve this using hooks?

(What I need is the following: I want to assign each layer a unique ID and build this DAG. Later, during inference, I want to pass an ID and, using hooks, save the output of only the layer corresponding to that ID. A rough sketch of what I mean is below.)
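
This is approximately what I have in mind; it is only a sketch, and the ID scheme (enumeration order of named_modules()) is something I made up for illustration:

import torch

model = SimpleResNet()

# Assign IDs by enumeration order of named_modules(); the root module has
# the empty name, so it is skipped (IDs here start at 1).
layers = {i: (name, mod)
          for i, (name, mod) in enumerate(model.named_modules())
          if name}

captured = {}

def capture(layer_id):
    # Register a hook that saves the output of only the layer with this ID.
    name, module = layers[layer_id]
    def hook(mod, inp, out):
        captured[name] = out.detach()
    return module.register_forward_hook(hook)

handle = capture(2)                      # ID 2 happens to be 'bn1' here
_ = model(torch.randn(1, 3, 32, 32))
handle.remove()                          # remove the hook when done
print({name: t.shape for name, t in captured.items()})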

Thanks again,
M