Creating a DAG (computational graph) of NN layers

Hi,

I am trying to create a DAG of all the layers/operations in an NN.

An example:

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out


class SimpleResNet(nn.Module):
    def __init__(self):
        super(SimpleResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # maxpool and avgpool are called in forward() but were never defined;
        # ResNet-style settings are assumed for both
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resblock1 = ResidualBlock(64, 64)
        self.resblock2 = ResidualBlock(64, 64)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```

Here I want each layer to be a node in my DAG.
For example:

• The input of node bn1 will be conv1, and its output will be relu.
• The inputs of resblock1.relu will be maxpool and resblock1.bn2, and its output will be resblock2.conv1.
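The second bullet hides a subtlety: `resblock1.relu` is one module called twice in `forward()`. If every module is a single node, the edges relu -> conv2 -> bn2 -> relu form a cycle, so the graph is no longer a DAG. A minimal sketch using Python's standard `graphlib` (3.9+) on a hand-written fragment of one `ResidualBlock`; the node names `add` and `relu_1` are illustrative choices, not something PyTorch produces here:

```python
from graphlib import CycleError, TopologicalSorter

# Fragment of one ResidualBlock as node -> set of predecessor nodes.
# With the shared `relu` module as ONE node, its two calls create a loop.
merged = {
    "conv1": {"input"},
    "bn1": {"conv1"},
    "relu": {"bn1", "bn2", "input"},  # 2nd call consumes bn2 (+ identity)
    "conv2": {"relu"},
    "bn2": {"conv2"},
}

# One node per call, plus an explicit node for `out += identity`.
per_call = {
    "conv1": {"input"},
    "bn1": {"conv1"},
    "relu": {"bn1"},
    "conv2": {"relu"},
    "bn2": {"conv2"},
    "add": {"bn2", "input"},
    "relu_1": {"add"},
}


def topo_order(preds):
    """Return a topological order, or None if the graph has a cycle."""
    try:
        return list(TopologicalSorter(preds).static_order())
    except CycleError:
        return None


print(topo_order(merged))    # None
print(topo_order(per_call))  # ['input', 'conv1', 'bn1', 'relu', 'conv2', 'bn2', 'add', 'relu_1']
```

So whatever tool builds the graph, a reused module needs one node per call for the result to be acyclic.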

I have tried using `.named_modules()` and `.named_children()`, but haven’t been able to make it work.
I’m guessing PyTorch already creates such a computational graph internally. Is there a good way to access it? Any suggestions for implementing this would be very helpful.
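A quick illustration of why `named_modules()` alone is not enough: it walks the module tree in the order submodules were assigned in `__init__` (the root first, under the empty name `''`) and says nothing about which output feeds which input. The small `Net` below is hypothetical, chosen so that registration order and call order differ:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Registered as fc, conv, pool ...
        self.fc = nn.Linear(8, 2)
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        # ... but called as conv -> pool -> flatten -> fc
        return self.fc(self.pool(self.conv(x)).flatten(1))


model = Net()
names = [name for name, _ in model.named_modules()]
print(names)  # ['', 'fc', 'conv', 'pool']
out = model(torch.randn(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 2])
```

The `flatten` call does not appear at all, because it is a tensor method rather than a module.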

Thanks!
M
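On whether PyTorch already builds such a graph: autograd does record one during the forward pass, reachable from `output.grad_fn` through each node's `next_functions`. Its nodes are backward functions (class names such as `AddBackward0`, which may differ between PyTorch versions) plus `AccumulateGrad` nodes for parameters, not module names, and it only exists while gradients are tracked. A minimal sketch that walks it; `autograd_dag` is a hypothetical helper name:

```python
import torch
import torch.nn as nn


def autograd_dag(output):
    """Walk the autograd graph backwards from output.grad_fn.

    Returns (labels, edges): labels maps each autograd node to a unique
    string; edges lists (producer, consumer) label pairs in forward order.
    """
    labels = {}
    edges = []

    def label(fn):
        if fn not in labels:
            labels[fn] = f"{type(fn).__name__}_{len(labels)}"
        return labels[fn]

    label(output.grad_fn)
    stack, visited = [output.grad_fn], set()
    while stack:
        fn = stack.pop()
        if fn in visited:
            continue
        visited.add(fn)
        for parent, _ in fn.next_functions:
            if parent is None:  # e.g. an input that does not require grad
                continue
            edges.append((label(parent), label(fn)))
            stack.append(parent)
    return labels, edges


lin = nn.Linear(4, 4)
x = torch.randn(2, 4)
out = torch.relu(lin(x) + x)  # residual-style add followed by ReLU
labels, edges = autograd_dag(out)
for src, dst in edges:
    print(f"{src} -> {dst}")
```

Because the nodes are autograd operations, mapping them back to `resblock1.bn2`-style layer names takes extra bookkeeping.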

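To build a Directed Acyclic Graph (DAG) of the layers in a PyTorch network, you can keep a small DAG class that stores each layer's successors and use forward hooks to observe the data flow. A forward hook on every leaf module remembers which module produced each output tensor; when another module later receives that tensor as input, the hook adds an edge from the producer to the consumer. Module hooks only fire for `nn.Module` calls, so functional operations such as `out += identity` and `torch.flatten` are not visible, and the edges that pass through them (the skip connections, and avgpool -> fc) are missing from the result.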

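```python
from collections import defaultdict

import torch
import torch.nn as nn


class DAG:
    def __init__(self):
        # Adjacency list: node name -> list of successor node names
        self.graph = defaultdict(list)

    def add_edge(self, u, v):
        if v not in self.graph[u]:
            self.graph[u].append(v)


class CustomDAGBuilder:
    def __init__(self, model):
        self.model = model
        self.dag = DAG()
        # id(tensor) -> (node name, tensor); the tensor is kept alive so
        # its id() cannot be reused by another tensor during the pass
        self.producers = {}
        self.calls = defaultdict(int)  # module name -> calls seen so far
        self.handles = []

    def register_hooks(self):
        # Hook only leaf modules; containers (SimpleResNet, ResidualBlock,
        # nn.Sequential, ...) are skipped so each real layer is one node.
        for name, module in self.model.named_modules():
            if next(module.children(), None) is None:
                self.handles.append(module.register_forward_hook(self.make_hook(name)))

    def make_hook(self, name):
        def forward_hook(module, inputs, output):
            # A module called twice (e.g. resblock1.relu) gets one node per
            # call ("resblock1.relu", "resblock1.relu_1") so no cycle appears.
            k = self.calls[name]
            self.calls[name] += 1
            node = name if k == 0 else f"{name}_{k}"
            for inp in inputs:
                if isinstance(inp, torch.Tensor) and id(inp) in self.producers:
                    self.dag.add_edge(self.producers[id(inp)][0], node)
            # In-place layers return their input object, so this overwrites
            # the producer of that tensor with the current node.
            self.producers[id(output)] = (node, output)
        return forward_hook

    def build_dag(self, x):
        self.producers = {id(x): ("input", x)}
        self.calls.clear()
        self.register_hooks()
        with torch.no_grad():
            self.model(x)
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
        self.producers.clear()

    def print_dag(self):
        for layer, connections in self.dag.graph.items():
            print(f'{layer} -> {connections}')


# Limitation: hooks only fire for nn.Module calls, so `out += identity` and
# `torch.flatten` are invisible; the identity edges and avgpool -> fc are missing.
model = SimpleResNet()
dag_builder = CustomDAGBuilder(model)
dag_builder.build_dag(torch.randn(1, 3, 32, 32))
dag_builder.print_dag()
```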
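An alternative to hooks, sketched here under the assumption that the model's `forward()` has no data-dependent Python control flow: `torch.fx.symbolic_trace` records `forward()` as a graph of nodes, including functional calls such as `torch.flatten` and the residual `+=`, and gives each call its own node (so a module called twice yields two nodes). The compact `Block`/`Net` classes and the `fx_dag` helper are hypothetical stand-ins for the `ResidualBlock`/`SimpleResNet` above, kept small so the example is self-contained:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace


class Block(nn.Module):
    def __init__(self, c):
        super().__init__()
        self.conv1 = nn.Conv2d(c, c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.bn1(self.conv1(x))
        out += x  # traced as an add node with two inputs
        return self.relu(out)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.resblock1 = Block(8)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.resblock1(self.conv1(x))
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


def fx_dag(model):
    """Return {node name: (op, target, input node names, output node names)}."""
    dag = {}
    for node in symbolic_trace(model).graph.nodes:
        # call_module targets are strings like "resblock1.conv1";
        # call_function targets are callables such as torch.flatten
        target = node.target if isinstance(node.target, str) else getattr(
            node.target, "__name__", str(node.target))
        dag[node.name] = (
            node.op,
            target,
            [n.name for n in node.all_input_nodes],
            [n.name for n in node.users],
        )
    return dag


for name, (op, target, ins, outs) in fx_dag(Net()).items():
    print(f"{name:16} {op:14} {target:16} in={ins} out={outs}")
```

Node names replace dots with underscores (`resblock1_conv1`), while the target keeps the module path (`resblock1.conv1`). Calling `fx_dag(SimpleResNet())` on the question's model works the same way once `maxpool` and `avgpool` are defined in `__init__`.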