Tracing a Graph of Torch Layers

Given a torch module, I would like to create a graph where the nodes correspond to torch modules (i.e. layer classes) and an edge from layer u to layer v means that the output of u eventually reaches the input of v (there might be intermediate functional operations between them).

Note that my definition of an edge isn't exactly precise. What I really mean is that I want to extract a directed graph whose transitive closure has the aforementioned property. In other words, all I care about is that whenever the input of v depends on the output of u, there is a path from u to v.
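
For example, these two toy graphs have the same transitive closure, so either one would be an acceptable answer here (a toy illustration using networkx, which I use later anyway):

        import networkx as nx

        # u -> w -> v already implies that v depends on u, so the explicit
        # shortcut edge (u, v) adds no new reachability information.
        g_sparse = nx.DiGraph([('u', 'w'), ('w', 'v')])
        g_dense = nx.DiGraph([('u', 'w'), ('w', 'v'), ('u', 'v')])
        assert set(nx.transitive_closure(g_sparse).edges) == set(nx.transitive_closure(g_dense).edges)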

First I attempted to use torch.jit.trace, but the problem was that I couldn't figure out how to associate the nodes in the JIT trace with the named layers in the torch module. I had a similar problem when trying to use the backward pass graph.
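
For concreteness, the kind of inspection I tried looked roughly like this (only a sketch, not a solution; whether scopeName() carries usable module paths seems to depend on the PyTorch version):

        import torch
        import torchvision

        net = torchvision.models.resnet18().eval()
        traced = torch.jit.trace(net, torch.rand(2, 3, 224, 224))

        for node in traced.inlined_graph.nodes():
            # node.kind() is a low-level op like 'aten::_convolution';
            # node.scopeName() is the only metadata I found that might tie a
            # node back to a named layer, and it was not obviously usable.
            print(node.kind(), repr(node.scopeName()))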

It seems that visualization tools like torchviz (as discussed in the SO thread "How do I visualize a net in Pytorch?") rely on these graphs, but they also don't seem to be able to recover the layer names the nodes correspond to.
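
For reference, a minimal torchviz invocation (assuming torchviz and the Graphviz binaries are installed) shows the limitation: the node labels come from grad_fn class names and parameter names rather than from the module hierarchy:

        import torch
        import torchvision
        from torchviz import make_dot

        net = torchvision.models.resnet18()
        out = net(torch.rand(2, 3, 224, 224))
        # Labels like 'layer1.0.conv1.weight' only appear because the named
        # parameters are passed in; intermediate nodes keep grad_fn names.
        make_dot(out, params=dict(net.named_parameters())).render('resnet18', format='png')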

I attempted an alternative approach where I record the inputs/outputs of each layer with a forward hook and connect two layers whenever one layer's input has the same id as another layer's output. However, this misses connections whenever an intermediate functional operation sits between sublayer forward passes:

        from itertools import count
        from collections import defaultdict
        import torch

        def model_layers(model):
            """ Extract named "leaf" layers from a module """
            stack = [('', '', model)]
            while stack:
                prefix, basename, item = stack.pop()
                name = '.'.join([p for p in [prefix, basename] if p])
                if isinstance(item, torch.nn.modules.conv._ConvNd):
                    yield name, item
                elif isinstance(item, torch.nn.modules.batchnorm._BatchNorm):
                    yield name, item
                elif hasattr(item, 'reset_parameters'):
                    # Heuristic: most parameterized leaf modules (e.g. Linear)
                    # define reset_parameters.
                    yield name, item

                child_prefix = name
                # Push children in reverse so they pop in definition order.
                for child_basename, child_item in list(item.named_children())[::-1]:
                    stack.append((child_prefix, child_basename, child_item))

        # Create example network
        import torchvision
        net = torchvision.models.resnet18()

        # The layers are the nodes we care about
        named_layers = list(model_layers(net))

        inputs = torch.rand(2, 3, 224, 224)

        layer_io_ids = {}
        counter = count()

        # Use forward hooks to record the tensors flowing into and out of
        # each layer; matching ids between outputs and inputs recovers some
        # of the edges.
        def _make_layer_hook(name):
            def record_layer_io(layer, args, output):
                iter_idx = next(counter)
                info = {
                    'iter_idx': iter_idx,
                }
                if len(args) == 1:
                    # Record the identity of the input tensor so it can be
                    # matched against another layer's output id later.
                    input_tensor = args[0]
                    info['input_ids'] = [id(input_tensor)]
                else:
                    raise NotImplementedError('multi-input layers are not handled')

                if isinstance(output, torch.Tensor):
                    info['output_ids'] = [id(output)]
                else:
                    raise NotImplementedError('non-tensor outputs are not handled')

                layer_io_ids[name] = info
            return record_layer_io

        for name, layer in named_layers:
            # Clear any previously registered hooks (this touches the private
            # _forward_hooks attribute) before adding ours.
            layer._forward_hooks.clear()
            hook = _make_layer_hook(name)
            layer.register_forward_hook(hook)

        # Forward pass to call the hooks
        outputs = net(inputs)

        input_id_to_name = defaultdict(list)
        output_id_to_name = defaultdict(list)
        for name, info in layer_io_ids.items():
            for input_id in info['input_ids']:
                input_id_to_name[input_id].append(name)
            for output_id in info['output_ids']:
                output_id_to_name[output_id].append(name)

        # Build the graph with the info we have
        import networkx as nx
        layer_graph = nx.DiGraph()

        for name, info in layer_io_ids.items():
            layer_graph.add_node(name)

        for name, info in layer_io_ids.items():
            output_ids = info['output_ids']
            for output_id in output_ids:
                if output_id in input_id_to_name:
                    for other in input_id_to_name[output_id]:
                        layer_graph.add_edge(name, other)

        nx.write_network_text(layer_graph, vertical_chains=1)

Running this produces something close to what I want (shown below), but I still need to figure out the connections between the current weakly connected components (these components are listed explicitly in the snippet after the output). My first idea was to use the order in which forward was called (which would give a topological sort of the graph I'm interested in), but a network may have branching paths (e.g. multiple heads), so call order alone cannot determine the connections.

        β•Ÿβ”€β”€ conv1
        β•Ž   β•½
        β•Ž   bn1
        β•Ÿβ”€β”€ layer1.0.conv1
        β•Ž   β•½
        β•Ž   layer1.0.bn1
        β•Ž   β•½
        β•Ž   layer1.0.conv2
        β•Ž   β•½
        β•Ž   layer1.0.bn2
        β•Ž   β•½
        β•Ž   layer1.1.conv1
        β•Ž   β•½
        β•Ž   layer1.1.bn1
        β•Ž   β•½
        β•Ž   layer1.1.conv2
        β•Ž   β•½
        β•Ž   layer1.1.bn2
        β•Ž   β”œβ”€β•Ό layer2.0.conv1
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.0.bn1
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.0.conv2
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.0.bn2
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.1.conv1
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.1.bn1 β•Ύ layer2.0.downsample.1
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.1.conv2
        β•Ž   β”‚   β•½
        β•Ž   β”‚   layer2.1.bn2
        β•Ž   β”‚   β”œβ”€β•Ό layer3.0.conv1
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.0.bn1
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.0.conv2
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.0.bn2
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.1.conv1
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.1.bn1 β•Ύ layer3.0.downsample.1
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.1.conv2
        β•Ž   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   layer3.1.bn2
        β•Ž   β”‚   β”‚   β”œβ”€β•Ό layer4.0.conv1
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.0.bn1
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.0.conv2
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.0.bn2
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.1.conv1
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.1.bn1 β•Ύ layer4.0.downsample.1
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.1.conv2
        β•Ž   β”‚   β”‚   β”‚   β•½
        β•Ž   β”‚   β”‚   β”‚   layer4.1.bn2
        β•Ž   β”‚   β”‚   └─╼ layer4.0.downsample.0
        β•Ž   β”‚   β”‚       β•½
        β•Ž   β”‚   β”‚       layer4.0.downsample.1
        β•Ž   β”‚   β”‚       └─╼  ...
        β•Ž   β”‚   └─╼ layer3.0.downsample.0
        β•Ž   β”‚       β•½
        β•Ž   β”‚       layer3.0.downsample.1
        β•Ž   β”‚       └─╼  ...
        β•Ž   └─╼ layer2.0.downsample.0
        β•Ž       β•½
        β•Ž       layer2.0.downsample.1
        β•Ž       └─╼  ...
        ╙── fc
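
For reference, the components that still need to be connected can be listed directly (continuing from the code above):

        # Each weakly connected component of layer_graph is a chain that
        # still needs to be joined to the others.
        for component in nx.weakly_connected_components(layer_graph):
            print(sorted(component))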

All in all, it seems like the JIT or backward graph is the only way to really get the information-flow topology, but I'd really like to see it in terms of the original layer names rather than the low-level functional components. Is there any way to associate the JIT or backward graph back to layer names?

EDIT: it looks like torchview (GitHub: mert-kurttutan/torchview) might do this? But maybe it doesn't get the layer names?

You could check out some of the blog posts by @tom.

Those examples use torch.jit.trace, which is indeed not overly intuitive, but the blog posts do give good ideas.

You should also check out torch.fx, which seems rather new and simpler / more intuitive than torch.jit.trace. But I haven't really figured this one out either :).
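
From the torch.fx docs, though, the core idea seems directly applicable: symbolic_trace records one node per call, and for 'call_module' nodes the node's target is exactly the qualified layer name, which is the missing association. A minimal sketch (assuming the model is symbolically traceable; data-dependent control flow breaks symbolic_trace):

        import networkx as nx
        import torch.fx
        import torchvision

        net = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(net)

        layer_graph = nx.DiGraph()
        # Map each fx node to the nearest upstream layer(s) feeding it, so
        # functional ops (add, flatten, ...) collapse into edges.
        node_to_layers = {}
        for node in traced.graph.nodes:
            upstream = set()
            for arg in node.all_input_nodes:
                upstream.update(node_to_layers.get(arg, set()))
            if node.op == 'call_module':
                # node.name is unique per call; node.target is the qualified
                # module name. Keying on node.name keeps reused modules
                # (e.g. a shared ReLU) from collapsing into one node.
                layer_graph.add_node(node.name, layer=node.target)
                for src in upstream:
                    layer_graph.add_edge(src, node.name)
                node_to_layers[node] = {node.name}
            else:
                # Placeholders, functions, and methods pass layer provenance
                # through without becoming nodes in the layer graph.
                node_to_layers[node] = upstream

        nx.write_network_text(layer_graph, vertical_chains=True)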


TorchLens does this, and remains in active development if there are any other features on your wishlist:
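
A hypothetical invocation, based on my reading of the TorchLens README (the entry-point name and arguments here are my assumption and may differ across versions):

        import torch
        import torchvision
        import torchlens as tl  # import name assumed from the README

        net = torchvision.models.resnet18()
        # log_forward_pass appears to be the main entry point; it records
        # every operation in the forward pass along with module metadata.
        history = tl.log_forward_pass(net, torch.rand(2, 3, 224, 224))
        print(history)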