Print Autograd Graph

Is there a way to visualize the graph of a model similar to what TensorFlow offers?


There will be TensorBoard integration in the future.
For the moment, you can use the visualize function from https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py
An example of how to use it can be found here.


I’ve made a simpler example that visualizes ResNet-18 using visualize.py, as @fmassa mentioned.
See https://gist.github.com/wangg12/f11258583ffcc4728eb71adc0f38e832.


I tried your code snippet. However, it doesn’t seem to visualize ResNet correctly. I also tried AlexNet and VGG-19, with the same result.

What is your result? What is the problem specifically? @zym1010

code:

%matplotlib inline
from graphviz import Digraph
import re
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models


def make_dot(var):
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def add_nodes(var):
        if var not in seen:
            if isinstance(var, Variable):
                value = '('+(', ').join(['%d'% v for v in var.size()])+')'
                dot.node(str(id(var)), str(value), fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'previous_functions'):
                for u in var.previous_functions:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])
    add_nodes(var.creator)
    return dot


inputs = torch.randn(1,3,224,224)
resnet18 = models.resnet18()
y = resnet18(Variable(inputs))
print(y)

g = make_dot(y)
g

result:

The result definitely doesn’t look like a ResNet, and similar things happen for AlexNet, VGG, etc.


Yes, the visualization code is currently broken for convnets because certain layers have C++ implementations that don’t expose the graph pointers to Python. It’ll be fixed after the autograd refactor.
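For readers on a current release (an assumption: PyTorch 1.x or later, where `Variable` has been merged into `Tensor`), the graph is reachable from Python through each tensor's `grad_fn` and each node's `next_functions`. Below is a minimal sketch that walks it and emits Graphviz DOT text as a plain string, so it needs neither the `graphviz` package nor the old `creator`/`previous_functions` attributes. `collect_graph` and `to_dot` are hypothetical helper names, not part of PyTorch or visualize.py.

```python
import torch


def collect_graph(output):
    """Walk the autograd graph backwards from `output`.

    Returns (nodes, edges): nodes maps id -> autograd node, and edges are
    (input_id, consumer_id) pairs, i.e. pointing in the forward direction.
    """
    nodes = {}  # keeping the node objects alive keeps their ids unique
    edges = []
    stack = [output.grad_fn] if output.grad_fn is not None else []
    while stack:
        fn = stack.pop()
        if id(fn) in nodes:
            continue
        nodes[id(fn)] = fn
        # next_functions holds (node, input_index) pairs; None marks inputs
        # that do not require grad.
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                edges.append((id(next_fn), id(fn)))
                stack.append(next_fn)
    return nodes, edges


def to_dot(nodes, edges):
    """Render the collected graph as Graphviz DOT source."""
    lines = ["digraph autograd {", "  node [shape=box];"]
    for node_id, fn in nodes.items():
        lines.append('  n%d [label="%s"];' % (node_id, type(fn).__name__))
    for src, dst in edges:
        lines.append("  n%d -> n%d;" % (src, dst))
    lines.append("}")
    return "\n".join(lines)


x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
nodes, edges = collect_graph(y)
print(to_dot(nodes, edges))
```

The printed text can be fed to any Graphviz renderer (for example `dot -Tpng`). Labels come straight from the node class names, so they will differ between PyTorch versions.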


This is my result for resnet18().
I cannot find any problem. @zym1010


Maybe a different PyTorch version? @jekbradbury

I’m having the same problem as @zym1010. Is this because I’m using a different PyTorch version? If not, can someone suggest how to go about fixing it?

@varun-suresh this problem should be fixed in master.


Same problem here. Any solutions?


I built PyTorch from source. I now get a KeyError when I try to build the graph. I’m using this script to build my graph. I’m trying to generate a graph for this network.

What’s the key that fails?

I built from the latest master and changed var.creator to var.grad_fn (because there is no creator attribute on master’s Variable, according to https://github.com/szagoruyko/functional-zoo/blob/master/visualize.py), but it still does not draw the graph the way @wangg12’s does.
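The rename is easy to confirm on a single operation. A small sketch, assuming a recent PyTorch release: `grad_fn` plays the role of the old `creator`, and `next_functions` the role of `previous_functions`. On such releases ReLU also has its own backward node, so the `Threshold` label mentioned just below should not appear there (as far as I know, older releases implemented `nn.ReLU` as a subclass of `nn.Threshold`).

```python
import torch

x = torch.randn(4, requires_grad=True)
y = torch.relu(x)

# grad_fn is the autograd node that produced y (what `creator` used to be).
node = y.grad_fn
print(type(node).__name__)  # the ReLU backward node

# next_functions lists (node, input_index) pairs; a leaf tensor that
# requires grad appears as an AccumulateGrad node whose .variable is the leaf.
leaf_node, index = node.next_functions[0]
print(type(leaf_node).__name__, index)
print(leaf_node.variable is x)
```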

The script @fmassa pointed to “worked” with a bit of modification (below), though I think it is not quite what we need. (I mean that labels like “Relu” would be better than “Threshold…”.)

from graphviz import Digraph
import torch
from torch.autograd import Variable


def make_dot(var, params):
    """ Produces Graphviz representation of PyTorch autograd graph
    
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    """
    param_map = {id(v): k for k, v in params.items()}
    
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()
    
    def size_to_str(size):
        return '('+(', ').join(['%d'% v for v in size])+')'

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                node_name = '%s\n %s' % (param_map.get(id(u)), size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)
    add_nodes(var.grad_fn)
    return dot

from torchvision import models
inputs = torch.randn(1,3,224,224)
resnet18 = models.resnet18()
y = resnet18(Variable(inputs))
# print(y)

# named_parameters() yields the parameter objects whose ids appear in the graph
g = make_dot(y, dict(resnet18.named_parameters()))
g.view()
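On why parameter labels can come out as `None` when `state_dict()` is passed in: `param_map` matches nodes by `id()`, and in the releases I am familiar with, `state_dict()` returns detached tensors by default (`keep_vars=False`), which are new objects whose ids never match the leaves in the graph, while `named_parameters()` yields the parameter objects themselves. A sketch on a toy model (`leaf_labels` is a hypothetical helper) that shows the difference:

```python
import torch
import torch.nn as nn


def leaf_labels(output, param_map):
    """Collect param_map labels for every AccumulateGrad (leaf) node
    reachable from output.grad_fn."""
    labels, seen, stack = [], set(), [output.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if hasattr(fn, "variable"):  # AccumulateGrad nodes wrap a leaf tensor
            labels.append(param_map.get(id(fn.variable)))
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return labels


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
out = model(torch.randn(1, 4))

by_param = {id(p): name for name, p in model.named_parameters()}
by_state = {id(t): name for name, t in model.state_dict().items()}

print(sorted(leaf_labels(out, by_param)))  # parameter names
print(leaf_labels(out, by_state))          # ids do not match: all None
```

Passing `state_dict(keep_vars=True)` should also work, since it then keeps the parameter objects instead of detaching them.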

My visualization is cut off at BatchNormBackward, which, as jekbradbury suggested, doesn’t properly expose the Function interface. I am using the wheel build, which apparently isn’t in sync with the source repo (I’m still using var.creator here).

Do things work the way they’re supposed to in the source build? Should we consider adding a proper interface to all Function objects to support visualization?

In case anyone’s still looking, here is another implementation that’s a bit more robust, as it does not depend on the autograd backward interface, which seems to be undergoing quite a bit of change. It also gets around the problem of the BatchNorm C backend not supporting the Python interface.
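The linked implementation is not reproduced here. As a separate illustration of the same general idea (not touching the autograd backward graph at all), one can record which leaf modules run, and in what order, with forward hooks. This is my own sketch, not the linked code, and `trace_modules` is a hypothetical name. It only sees `nn.Module` calls, so functional ops such as the residual additions in ResNet will not show up.

```python
import torch
import torch.nn as nn


def trace_modules(model, *inputs):
    """Run `model` once and return (name, class name, output shape) for each
    leaf module, in the order the modules were called."""
    records, handles = [], []

    def make_hook(name):
        def hook(module, args, output):
            shape = tuple(output.shape) if torch.is_tensor(output) else None
            records.append((name, type(module).__name__, shape))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():  # no autograd graph is needed for this
            model(*inputs)
    finally:
        for handle in handles:  # leave the model without extra hooks
            handle.remove()
    return records


model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.BatchNorm2d(4), nn.ReLU())
records = trace_modules(model.eval(), torch.randn(1, 1, 8, 8))
for record in records:
    print(record)
```

BatchNorm is recorded like any other module here, because nothing depends on how its backward is implemented.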

@hyqneuron I’ve had really good results with your vis implementation, but it runs into trouble with functions that can take a varying number of inputs. Are you going to be maintaining and updating the code going forward?

Hello @moskomule & @zym1010,

thank you for the code.
I’m wondering if I can do the same with my model.
I tried the following:

model = crnn.CRNN(32, 1, 37, 256, 1).cuda()  # it's my model
print(model) 

CRNN (
  (cnn): Sequential (
    (conv0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu0): ReLU (inplace)
    (pooling0): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU (inplace)
    (pooling1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (conv2): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
    (relu2): ReLU (inplace)
    (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU (inplace)
    (pooling2): MaxPool2d (size=(2, 2), stride=(2, 1), dilation=(1, 1))
    (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (relu4): ReLU (inplace)
    (conv5): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu5): ReLU (inplace)
    (pooling3): MaxPool2d (size=(2, 2), stride=(2, 1), dilation=(1, 1))
    (conv6): Conv2d(512, 512, kernel_size=(2, 2), stride=(1, 1))
    (batchnorm6): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
    (relu6): ReLU (inplace)
  )
  (rnn): Sequential (
    (0): BidirectionalLSTM (
      (rnn): LSTM(512, 256, bidirectional=True)
      (embedding): Linear (512 -> 256)
    )
    (1): BidirectionalLSTM (
      (rnn): LSTM(256, 256, bidirectional=True)
      (embedding): Linear (512 -> 37)
    )
  )
)

However, I can’t visualize the graph.
I followed exactly the same steps as you described:

inputs = torch.randn(1,3,224,224)
model = crnn.CRNN(32, 1, 37,256, 1).cuda()
y = model(Variable(inputs))  # I got an error at this line
print(y)
g = make_dot(y)
g 

After executing the line that computes y, I got the following error:

======= Backtrace: =========
/lib/x86_64-linux-gnu/libc.so.6(+0x777e5)[0x7ff2608057e5]
/lib/x86_64-linux-gnu/libc.so.6(+0x8037a)[0x7ff26080e37a]
/lib/x86_64-linux-gnu/libc.so.6(cfree+0x4c)[0x7ff26081253c]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/../../libstdc++.so.6(_ZNSt15basic_stringbufIcSt11char_traitsIcESaIcEE8overflowEi+0x13b)[0x7ff24918604b]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/../../libstdc++.so.6(_ZNSt15basic_streambufIcSt11char_traitsIcEE6xsputnEPKcl+0x36)[0x7ff24918a1b6]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/lib/libshm.so(_ZSt16__ostream_insertIcSt11char_traitsIcEERSt13basic_ostreamIT_T0_ES6_PKS3_l+0x1c5)[0x7ff23b391235]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/_C.so(+0x5d2842)[0x7ff23bc12842]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/_C.so(+0x5d34ae)[0x7ff23bc134ae]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/_C.so(_ZN5torch2nn33SpatialConvolutionMM_updateOutputEPN4thpp6TensorES3_S3_S3_S3_S3_iiiiii+0xb3)[0x7ff23bc271a3]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/_C.so(+0x5caf27)[0x7ff23bc0af27]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/_C.so(_ZN5torch8autograd11ConvForward5applyERKSt6vectorISt10shared_ptrINS0_8VariableEESaIS5_EE+0x17bf)[0x7ff23bc0f65f]
/home/ahmed/anaconda3/envs/cv/lib/python2.7/site-packages/torch/_C.so(+0x5c191b)[0x7ff23bc0191b]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyObject_Call+0x53)[0x7ff2614cee93]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalFrameEx+0x715d)[0x7ff26158180d]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalCodeEx+0x89e)[0x7ff261583c3e]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalFrameEx+0x8b47)[0x7ff2615831f7]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalCodeEx+0x89e)[0x7ff261583c3e]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(+0x79b68)[0x7ff2614feb68]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyObject_Call+0x53)[0x7ff2614cee93]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalFrameEx+0x61d6)[0x7ff261580886]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalCodeEx+0x89e)[0x7ff261583c3e]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(+0x79a61)[0x7ff2614fea61]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyObject_Call+0x53)[0x7ff2614cee93]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(+0x5c64f)[0x7ff2614e164f]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyObject_Call+0x53)[0x7ff2614cee93]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(+0xba2ac)[0x7ff26153f2ac]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyObject_Call+0x53)[0x7ff2614cee93]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalFrameEx+0x715d)[0x7ff26158180d]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(PyEval_EvalCodeEx+0x89e)[0x7ff261583c3e]
/home/ahmed/anaconda3/envs/cv/bin/../lib/libpython2.7.so.1.0(+0x79b68)[0x7ff2614feb68]

7ff261a9e000-7ff261a9f000 rw-s 133756000 00:06 541                       /dev/nvidiactl
7ff261a9f000-7ff261aa0000 rw-s 9fee8000 00:06 542                        /dev/nvidia0
7ff261aa0000-7ff261aa1000 rw-s 13372f000 00:06 541                       /dev/nvidiactl
7ff261aa1000-7ff261aa2000 rw-s 9fee8000 00:06 542                        /dev/nvidia0
7ff261aa2000-7ff261aa3000 rw-s 133790000 00:06 541                       /dev/nvidiactl
7ff261aa3000-7ff261aa4000 rwxp 00000000 00:00 0 
7ff261aa4000-7ff261aa6000 rw-p 00000000 00:00 0 
7ff261aa6000-7ff261aa7000 r--p 00025000 08:01 2101713                    /lib/x86_64-linux-gnu/ld-2.23.so
7ff261aa7000-7ff261aa8000 rw-p 00026000 08:01 2101713                    /lib/x86_64-linux-gnu/ld-2.23.so
7ff261aa8000-7ff261aa9000 rw-p 00000000 00:00 0 
7fffff3a5000-7fffff3c8000 rwxp 00000000 00:00 0                          [stack]
7fffff3c8000-7fffff3ca000 rw-p 00000000 00:00 0 
7fffff3ed000-7fffff3ef000 r--p 00000000 00:00 0                          [vvar]
7fffff3ef000-7fffff3f1000 r-xp 00000000 00:00 0                          [vdso]
ffffffffff600000-ffffffffff601000 r-xp 00000000 00:00 0                  [vsyscall]
Process finished with exit code 134 (interrupted by signal 6: SIGABRT)

Thank you for your help
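A note on the snippet itself, going only by what the post shows: the printed model starts with `(conv0): Conv2d(1, 64, ...)`, so it expects 1-channel input, while `torch.randn(1,3,224,224)` has 3 channels; and the model is moved to the GPU with `.cuda()` while `inputs` stays on the CPU. The backtrace ends in a CPU convolution routine (`SpatialConvolutionMM_updateOutput`), which is consistent with that device mismatch, though I can’t confirm it is the cause. Assuming the first argument of `crnn.CRNN` is the input image height (32 here), here is a sketch of a hypothetical helper that builds a dummy input matching the first convolution’s channel count and device. A small stand-in model replaces the CRNN, since the `crnn` module isn’t available here.

```python
import torch
import torch.nn as nn


def make_dummy_input(model, height, width, batch=1):
    """Hypothetical helper: random input whose channel count and device
    match the first nn.Conv2d found in `model`."""
    first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
    return torch.randn(batch, first_conv.in_channels, height, width,
                       device=first_conv.weight.device)


# Stand-in for the CRNN; with the real model this would be something like
# make_dummy_input(model, 32, 100), where 100 is an arbitrary width.
toy = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU())
inputs = make_dummy_input(toy, 32, 100)
y = toy(inputs)
print(inputs.shape, y.shape)
```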