Hi, today I tried to use torch.fx.passes.graph_drawer to visualize a GraphModule,
like this:
import torch
from torch.fx import passes
model = torch.load('graph_model_resnet50.pt')  # a GraphModule produced by torch.fx.symbolic_trace
g = passes.graph_drawer.FxGraphDrawer(model, 'resnet50')  # graph_drawer is a module; FxGraphDrawer is the class
but FxGraphDrawer does not seem to have the get_dot_graph().create_svg() API, even though its docstring says:
Visualize a torch.fx.Graph with graphviz
Basic usage:
g = FxGraphDrawer(symbolic_traced, "resnet18")
with open("a.svg", "w") as f:
    f.write(g.get_dot_graph().create_svg())
So I tried:
g = FxGraphDrawer(symbolic_traced, "resnet18")
with open("a.svg", "w") as f:
    f.write(g.get_main_dot_graph().__str__())
but the resulting SVG file cannot be opened.
So how can I do this and successfully save an SVG file?
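Both failures are consistent with how pydot behaves (my reading, not something confirmed in this thread): get_main_dot_graph() returns a pydot.Dot, and its string form is DOT source text (starting with `digraph`), not SVG, so saving it as a.svg produces a file that SVG viewers reject. Conversely, create_svg() returns bytes, so the docstring's open("a.svg", "w") makes f.write(...) raise a TypeError. The stdlib-only sketch below illustrates both points; save_graph_output and looks_like_svg are hypothetical helper names, and the `<svg` check is only a rough heuristic.

```python
import os
import tempfile
from pathlib import Path


def save_graph_output(data, path):
    """Write renderer output to `path`, picking the file mode from its type.

    bytes (e.g. from pydot's create_svg()) need binary mode "wb";
    str (e.g. str(dot_graph), which is DOT source) needs text mode "w".
    """
    mode = "wb" if isinstance(data, (bytes, bytearray)) else "w"
    with open(path, mode) as f:
        f.write(data)


def looks_like_svg(path):
    """Heuristic check: Graphviz SVG output has an <svg element near the top."""
    head = Path(path).read_bytes()[:4096]
    return b"<svg" in head


if __name__ == "__main__":
    tmp = tempfile.mkdtemp()
    # DOT source saved under a .svg name is still DOT text, not a real SVG file
    dot_path = os.path.join(tmp, "from_str.svg")
    save_graph_output("digraph G { a -> b }", dot_path)
    print(looks_like_svg(dot_path))  # False
```

In other words, a file written from str(...) is really a .dot file; it only becomes SVG after Graphviz renders it.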
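import torch
import torchvision.models as models
from torch.fx import passes, symbolic_trace
# Trace ResNet-18 into a torch.fx.GraphModule
model = models.resnet18()
model = symbolic_trace(model)
# FxGraphDrawer needs pydot and the Graphviz binaries installed
g = passes.graph_drawer.FxGraphDrawer(model, 'resnet18')
# get_dot_graph() returns a pydot.Dot; create_svg() renders it with Graphviz
# and returns bytes, so the file must be opened in binary mode ("wb")
with open("a.svg", "wb") as f:
    f.write(g.get_dot_graph().create_svg())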
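If you already have the DOT text (as produced by the get_main_dot_graph().__str__() attempt above), another option is to render it yourself with the Graphviz `dot` command-line tool. This is a sketch that assumes `dot` is on your PATH; render_dot_to_svg is a hypothetical helper, not part of torch.fx or pydot.

```python
import shutil
import subprocess


def render_dot_to_svg(dot_source: str) -> bytes:
    """Render DOT-language text to SVG bytes using the Graphviz `dot` CLI."""
    dot = shutil.which("dot")
    if dot is None:
        raise RuntimeError("Graphviz `dot` executable not found on PATH")
    # `dot -Tsvg` reads the graph from stdin and writes SVG to stdout
    result = subprocess.run(
        [dot, "-Tsvg"],
        input=dot_source.encode("utf-8"),
        capture_output=True,
        check=True,
    )
    return result.stdout


# Hypothetical usage with the FxGraphDrawer `g` from the code above:
#   svg = render_dot_to_svg(str(g.get_main_dot_graph()))
#   with open("a.svg", "wb") as f:
#       f.write(svg)
```

The result is bytes, so it is written in "wb" mode, for the same reason as create_svg().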