About torch.fx.passes.graph_drawer

Hi, today I tried to use torch.fx.passes.graph_drawer to visualize a GraphModule,

like this:

import torch
from torch.fx import passes

model = torch.load('graph_model_resnet50.pt')  # a GraphModule produced by torch.fx.symbolic_trace

passes.graph_drawer(model, 'resnet50')

but it raises this error:

TypeError: 'module' object is not callable

I can't find the reason for this error.


torch.fx.passes.graph_drawer is a module, not a method.
Use passes.graph_drawer.FxGraphDrawer(model, 'resnet50') and it should work.
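To see why the original call fails, here is a small torch-free sketch. Calling any module object raises this TypeError, while calling a class defined inside that module works. The `graph_drawer` module and `FxGraphDrawer` class below are hypothetical stand-ins built with `types.ModuleType`. They are not the real PyTorch objects.

```python
import types

# Hypothetical stand-in for torch.fx.passes.graph_drawer (NOT the real module).
graph_drawer = types.ModuleType("graph_drawer")


class FxGraphDrawer:
    """Hypothetical stand-in for the real FxGraphDrawer class."""

    def __init__(self, graph_module, name):
        self.graph_module = graph_module
        self.name = name


# Attach the class to the module, mirroring graph_drawer.FxGraphDrawer.
graph_drawer.FxGraphDrawer = FxGraphDrawer

# Calling the module itself fails, just like passes.graph_drawer(model, ...).
try:
    graph_drawer(object(), "resnet50")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# Calling the class that lives inside the module succeeds.
drawer = graph_drawer.FxGraphDrawer(object(), "resnet50")
```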

Thanks, @ptrblck.

I found FxGraphDrawer in graph_drawer.py.
PS: pytorch/graph_drawer.py at master · pytorch/pytorch · GitHub

But when I follow the basic usage from its docstring, my FxGraphDrawer has no get_dot_graph() method, so get_dot_graph().create_svg() is not available:

Visualize a torch.fx.Graph with graphviz
Basic usage:
    g = FxGraphDrawer(symbolic_traced, "resnet18")
    with open("a.svg", "w") as f:
        f.write(g.get_dot_graph().create_svg())

So I tried:

    g = FxGraphDrawer(symbolic_traced, "resnet18")
    with open("a.svg", "w") as f:
        f.write(g.get_main_dot_graph().__str__())

but the resulting SVG file cannot be loaded.
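One likely reason the file would not open is that `get_main_dot_graph()` returns a pydot graph. Converting that graph with `__str__()` produces Graphviz DOT source text, which typically begins with `digraph`, rather than SVG markup. The torch-free sketch below checks which of the two a saved file actually contains. The helper name `sniff_graph_file` is hypothetical, and it only looks at the first few kilobytes of the file.

```python
def sniff_graph_file(path, head_bytes=2048):
    """Guess whether a file holds SVG markup or Graphviz DOT source.

    Hypothetical helper: it inspects only the first `head_bytes` bytes.
    Returns "svg", "dot", or "unknown".
    """
    with open(path, "rb") as f:
        head = f.read(head_bytes)
    text = head.decode("utf-8", errors="replace").lstrip()
    # SVG from Graphviz starts with an XML prolog followed by an <svg> element.
    if "<svg" in text:
        return "svg"
    # DOT source starts with an optional "strict", then "graph" or "digraph".
    first_word = text.split(None, 1)[0] if text else ""
    if first_word in {"strict", "graph", "digraph"}:
        return "dot"
    return "unknown"
```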

So how can I successfully save an SVG file?

Perhaps it's a version problem :face_with_head_bandage:

This code works for me and creates an svg file:

import torch
import torchvision.models as models
from torch.fx import passes, symbolic_trace

model = models.resnet18()
model = symbolic_trace(model)

g = passes.graph_drawer.FxGraphDrawer(model, 'resnet18')
with open("a.svg", "wb") as f:  # create_svg() returns bytes, hence "wb"
    f.write(g.get_dot_graph().create_svg())
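If an older PyTorch build only offers `get_main_dot_graph()`, as in the earlier attempt, a version-tolerant sketch could fall back to that method and still call `create_svg()` on the returned pydot graph instead of `__str__()`. This assumes `get_main_dot_graph()` returns a `pydot.Dot` in those builds. The helper name `save_svg` is hypothetical. Rendering also needs pydot and the Graphviz `dot` binary installed.

```python
def save_svg(drawer, path):
    """Render an FxGraphDrawer's main graph to an SVG file (hypothetical helper).

    Prefers get_dot_graph() when present and falls back to get_main_dot_graph()
    for builds that lack it (an assumption based on this thread).
    Returns the number of bytes written.
    """
    get_graph = getattr(drawer, "get_dot_graph", None)
    if get_graph is None:
        get_graph = drawer.get_main_dot_graph
    # pydot renders SVG by invoking Graphviz; the result is bytes, not str.
    svg_bytes = get_graph().create_svg()
    with open(path, "wb") as f:  # binary mode because svg_bytes is bytes
        f.write(svg_bytes)
    return len(svg_bytes)


# Usage (requires torch, torchvision, pydot and Graphviz):
# import torchvision.models as models
# from torch.fx import passes, symbolic_trace
# gm = symbolic_trace(models.resnet18())
# save_svg(passes.graph_drawer.FxGraphDrawer(gm, "resnet18"), "resnet18.svg")
```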

Oh, awesome. It works, thanks @ptrblck.

Good to hear it’s working for you, too!
