Edit: I've tried to make the question clearer.
I need to traverse a computation graph in order to plot a diagram of it.
What I currently do is:
I start from my dummy loss scalar and recursively follow .next_functions.
This lets me "visit" operations as well as parameters.
HOWEVER, I don't manage to visit saved tensors (for example, activations that are saved for the backward pass).
I can see "old" PyTorch code that assumes the presence of .saved_tensors, but I never encounter this attribute while traversing the graph.
Is .saved_tensors still accessible? If not, any suggestion on how to reach the saved tensors?
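To make the setup concrete, here is a minimal, self-contained sketch of the traversal I'm describing (the helper name `walk` is mine, just for illustration): start from the loss's grad_fn and recurse into .next_functions.

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = (x * 2).sum()

seen = set()

def walk(fn):
    # Recursively visit backward-graph nodes reachable via .next_functions.
    if fn is None or fn in seen:
        return
    seen.add(fn)
    print(type(fn).__name__)  # e.g. SumBackward0, MulBackward0, AccumulateGrad
    for child, _ in fn.next_functions:
        walk(child)

walk(loss.grad_fn)
```

This reaches the op nodes and the AccumulateGrad leaves (i.e. the parameters), but no saved tensors ever show up along the way.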
I’m looking at the following reference:
snippet here - full link at the bottom of the post.
def add_nodes(var):
    if var not in seen:
        if torch.is_tensor(var):
            dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
        elif hasattr(var, 'variable'):
            u = var.variable
            name = param_map[id(u)] if params is not None else ''
            node_name = '%s\n %s' % (name, size_to_str(u.size()))
            dot.node(str(id(var)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(var)), str(type(var).__name__))
        seen.add(var)
        if hasattr(var, 'next_functions'):
            for u in var.next_functions:
                if u[0] is not None:
                    dot.edge(str(id(u[0])), str(id(var)))
                    add_nodes(u[0])
        if hasattr(var, 'saved_tensors'):
            for t in var.saved_tensors:
                dot.edge(str(id(t)), str(id(var)))
                add_nodes(t)

add_nodes(var.grad_fn)
note - the "orange nodes" seem to be designed with the intention of displaying exactly what I want, but they never appear, since .saved_tensors is never found on any node.
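One thing I noticed while poking around (this may depend on the PyTorch version, so I'm not sure it's the supported way): built-in backward nodes seem to expose their saved values as per-op attributes prefixed with `_saved_`, rather than a generic .saved_tensors. For example, MulBackward0 appears to keep both operands:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
loss = (x * y).sum()

# Step from SumBackward0 down to the MulBackward0 node.
mul_node = loss.grad_fn.next_functions[0][0]

# List whatever saved-value attributes this node exposes.
saved_attrs = [a for a in dir(mul_node) if a.startswith('_saved')]
print(type(mul_node).__name__, saved_attrs)
```

If that's reliable, I could scan `dir(fn)` for `_saved_*` attributes during traversal, but I'd love confirmation that this is the intended API.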