Given a torch module, I would like to create a graph where the nodes correspond to torch modules (i.e. the named layers) and an edge from layer u to layer v means that the output of u eventually reaches the input of v (there might be intermediate functional operations between them).
Note, my definition of an edge isn't exactly precise. What I really mean is that I want to extract a directed graph whose transitive closure has the aforementioned property. In other words, in the graph I actually care about, all that matters is that if the input of v depends on the output of u, there is a path from u to v.
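For example, in this toy graph the edge conv1 -> fc only needs to appear in the transitive closure, not in the graph itself:

import networkx as nx
# conv1 -> bn1 -> fc: fc transitively depends on conv1, so the
# closure must contain the edge conv1 -> fc even though the graph doesn't.
toy = nx.DiGraph([('conv1', 'bn1'), ('bn1', 'fc')])
closure = nx.transitive_closure(toy)
assert closure.has_edge('conv1', 'fc')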
First I attempted to use torch.jit.trace, but the problem was that I couldn't figure out how to associate the nodes in the JIT trace with the named layers in the torch module. I had a similar problem when trying to use the backward-pass graph.
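For reference, the closest handle I found on the trace was node.scopeName(), probed roughly like this:

import torch
import torchvision
net = torchvision.models.resnet18().eval()
traced = torch.jit.trace(net, torch.rand(2, 3, 224, 224))
for node in traced.graph.nodes():
    # kind() is the low-level op (e.g. aten::conv2d); scopeName() is the
    # closest thing to a layer name I found, but I couldn't reliably map
    # it back to the names from net.named_modules().
    print(node.kind(), node.scopeName())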
It seems that tools that visualize these graphs, like torchviz (as discussed in the SO thread "How do I visualize a net in PyTorch?" on Stack Overflow), rely on the JIT trace or the backward graph, but they also don't seem to be able to grab the layer names the nodes correspond to.
I attempted an alternative approach where I look at the inputs / outputs of each layer using a forward hook and connect layer u to layer v whenever one of u's output tensors has the same id() as one of v's input tensors, but this misses connections whenever an intermediate functional operation sits between sublayer forward passes:
from itertools import count
from collections import defaultdict
import torch
def model_layers(model):
    """ Extract named "leaf" layers from a module """
    stack = [('', '', model)]
    while stack:
        prefix, basename, item = stack.pop()
        name = '.'.join([p for p in [prefix, basename] if p])
        if isinstance(item, torch.nn.modules.conv._ConvNd):
            yield name, item
        elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
            yield name, item
        elif hasattr(item, 'reset_parameters'):
            yield name, item
        child_prefix = name
        # Reverse the children so the stack pops them in definition order
        for child_basename, child_item in list(item.named_children())[::-1]:
            stack.append((child_prefix, child_basename, child_item))
# Create example network
import torchvision
net = torchvision.models.resnet18()
# The layers are the nodes we care about
named_layers = list(model_layers(net))
inputs = torch.rand(2, 3, 224, 224)
layer_io_ids = {}
counter = count()
# Use hooks to trace the inputs/outputs to extract some of the edges
def _make_layer_hook(name):
    def record_layer_io(layer, args, output):
        iter_idx = next(counter)
        info = {
            'iter_idx': iter_idx,
        }
        if len(args) == 1:
            input_tensor = args[0]
            info['input_ids'] = [id(input_tensor)]
        else:
            raise NotImplementedError
        if isinstance(output, torch.Tensor):
            info['output_ids'] = [id(output)]
        else:
            raise NotImplementedError
        layer_io_ids[name] = info
    return record_layer_io
for name, layer in named_layers:
    layer._forward_hooks.clear()
    hook = _make_layer_hook(name)
    layer.register_forward_hook(hook)
# Forward pass to call the hooks
outputs = net(inputs)
input_id_to_name = defaultdict(list)
output_id_to_name = defaultdict(list)
for name, info in layer_io_ids.items():
    for input_id in info['input_ids']:
        input_id_to_name[input_id].append(name)
    for output_id in info['output_ids']:
        output_id_to_name[output_id].append(name)
# Build the graph with the info we have
import networkx as nx
layer_graph = nx.DiGraph()
for name, info in layer_io_ids.items():
    layer_graph.add_node(name)
for name, info in layer_io_ids.items():
    output_ids = info['output_ids']
    for output_id in output_ids:
        if output_id in input_id_to_name:
            for other in input_id_to_name[output_id]:
                layer_graph.add_edge(name, other)
nx.write_network_text(layer_graph, vertical_chains=1)
results in something close to what I want, but I still need to figure out the connections between the current weakly connected components (one idea for bridging these gaps is sketched after the tree below). My first idea was to use the order in which forward was called (to get a topological sort of the graph I'm interested in), but a network can have branching paths (e.g. multiple heads), so that strategy doesn't work in general.
╟── conv1
╎   ╽
╎   bn1
╟── layer1.0.conv1
╎   ╽
╎   layer1.0.bn1
╎   ╽
╎   layer1.0.conv2
╎   ╽
╎   layer1.0.bn2
╎   ╽
╎   layer1.1.conv1
╎   ╽
╎   layer1.1.bn1
╎   ╽
╎   layer1.1.conv2
╎   ╽
╎   layer1.1.bn2
╎   ├─╼ layer2.0.conv1
╎   │   ╽
╎   │   layer2.0.bn1
╎   │   ╽
╎   │   layer2.0.conv2
╎   │   ╽
╎   │   layer2.0.bn2
╎   │   ╽
╎   │   layer2.1.conv1
╎   │   ╽
╎   │   layer2.1.bn1 ╾ layer2.0.downsample.1
╎   │   ╽
╎   │   layer2.1.conv2
╎   │   ╽
╎   │   layer2.1.bn2
╎   │   ├─╼ layer3.0.conv1
╎   │   │   ╽
╎   │   │   layer3.0.bn1
╎   │   │   ╽
╎   │   │   layer3.0.conv2
╎   │   │   ╽
╎   │   │   layer3.0.bn2
╎   │   │   ╽
╎   │   │   layer3.1.conv1
╎   │   │   ╽
╎   │   │   layer3.1.bn1 ╾ layer3.0.downsample.1
╎   │   │   ╽
╎   │   │   layer3.1.conv2
╎   │   │   ╽
╎   │   │   layer3.1.bn2
╎   │   │   ├─╼ layer4.0.conv1
╎   │   │   │   ╽
╎   │   │   │   layer4.0.bn1
╎   │   │   │   ╽
╎   │   │   │   layer4.0.conv2
╎   │   │   │   ╽
╎   │   │   │   layer4.0.bn2
╎   │   │   │   ╽
╎   │   │   │   layer4.1.conv1
╎   │   │   │   ╽
╎   │   │   │   layer4.1.bn1 ╾ layer4.0.downsample.1
╎   │   │   │   ╽
╎   │   │   │   layer4.1.conv2
╎   │   │   │   ╽
╎   │   │   │   layer4.1.bn2
╎   │   │   └─╼ layer4.0.downsample.0
╎   │   │       ╽
╎   │   │       layer4.0.downsample.1
╎   │   │       └─╼ ...
╎   │   └─╼ layer3.0.downsample.0
╎   │       ╽
╎   │       layer3.0.downsample.1
╎   │       └─╼ ...
╎   └─╼ layer2.0.downsample.0
╎       ╽
╎       layer2.0.downsample.1
╎       └─╼ ...
╙── fc
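One idea I haven't finished for bridging these gaps: have the forward hooks also record id(output.grad_fn) -> name in a dict (say fn_to_name, a hypothetical extension of the hooks above), then walk each recorded grad_fn's next_functions backwards through the intermediate functional ops until another recorded layer is found. A rough, untested sketch:

def nearest_recorded_parents(grad_fn, fn_to_name):
    """ Yield names of the nearest recorded layers upstream of grad_fn,
    skipping over the functional ops in between. """
    stack = [fn for fn, _ in grad_fn.next_functions if fn is not None]
    seen = set()
    while stack:
        fn = stack.pop()
        if id(fn) in seen:
            continue
        seen.add(id(fn))
        if id(fn) in fn_to_name:
            yield fn_to_name[id(fn)]
        else:
            stack.extend(f for f, _ in fn.next_functions if f is not None)

Adding an edge from every yielded parent to each recorded layer should connect the weakly connected components through the functional ops.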
All in all it seems like the JIT or backward graph is the only way to really get the information-flow topology, but I'd really like to see it in terms of the original layer names and not the low-level functional components. Is there any way to associate the JIT or backward graph back to layer names?
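One thing I haven't fully explored: torch.fx keeps the qualified module name on its graph (a call_module node's target is exactly the name from named_modules()), so for symbolically-traceable networks something like this sketch might already give the graph I want:

import networkx as nx
import torch.fx
import torchvision

net = torchvision.models.resnet18()
fx_graph = torch.fx.symbolic_trace(net).graph

def nearest_module_ancestors(node, seen=None):
    """ Walk backwards through functional ops (call_function etc.) to
    the closest call_module ancestors. """
    if seen is None:
        seen = set()
    for pred in node.all_input_nodes:
        if pred in seen:
            continue
        seen.add(pred)
        if pred.op == 'call_module':
            yield pred.target
        else:
            yield from nearest_module_ancestors(pred, seen)

layer_graph = nx.DiGraph()
for node in fx_graph.nodes:
    if node.op == 'call_module':
        layer_graph.add_node(node.target)
        for parent in nearest_module_ancestors(node):
            layer_graph.add_edge(parent, node.target)

One caveat: a module instance reused in forward (e.g. resnet's self.relu) collapses into a single node here, whereas the trace would show each call separately.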
EDIT: it looks like torchview (GitHub: mert-kurttutan/torchview) might do this? But maybe it doesn't get the layer names?
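If it does, usage per its README would be something like the following; I haven't verified whether the resulting nodes carry the qualified layer names:

import torchvision
from torchview import draw_graph

net = torchvision.models.resnet18()
model_graph = draw_graph(net, input_size=(2, 3, 224, 224))
model_graph.visual_graph  # graphviz Digraph of the model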