Iterating over nodes in torch._C.Graph

How can I find out exactly which nodes are present in a PyTorch model graph, and what their inputs are?
I tried to fetch a torch._C.Graph object using


scripted = torch.jit.script(MyModel().eval())
frozen_module = torch.jit.freeze(scripted)
print(frozen_module.inlined_graph)

which gave the following output

graph(%self : __torch__.___torch_mangle_2.MyModel,
      %x1.1 : Tensor,
      %x2.1 : Tensor,
      %x3.1 : Tensor):
  %4 : Float(52229:1, 4:52229, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %5 : Float(10:1, 5:10, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %6 : int[] = prim::Constant[value=[0, 0]]()
  %7 : int[] = prim::Constant[value=[2, 2]]()
  %8 : int[] = prim::Constant[value=[1, 1]]()
  %9 : int = prim::Constant[value=2]() 
  %10 : bool = prim::Constant[value=0]()
  %11 : int = prim::Constant[value=1]() # test.py:39:34
  %12 : int = prim::Constant[value=0]() # test.py:39:29
  %13 : int = prim::Constant[value=-1]() # test.py:39:33
  %self.classifier.bias : Float(4:1, requires_grad=0, device=cpu) = prim::Constant[value=0.001 *  2.8424  1.0601 -1.3229  4.2920 [ CPUFloatType{4} ]]()
  %self.features3.0.bias : Float(5:1, requires_grad=0, device=cpu) = prim::Constant[value= 0.0111 -0.0702  0.1396  0.1691  0.1335 [ CPUFloatType{5} ]]()
  %self.features2.0.bias : Float(3:1, requires_grad=0, device=cpu) = prim::Constant[value= 0.3314  0.0165  0.2588 [ CPUFloatType{3} ]]()
  %self.features2.0.weight : Float(3:9, 1:9, 3:3, 3:1, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %self.features1.0.bias : Float(3:1, requires_grad=0, device=cpu) = prim::Constant[value=0.01 *  2.5380 -31.8947 -15.3462 [ CPUFloatType{3} ]]()
  %self.features1.0.weight : Float(3:9, 1:9, 3:3, 3:1, requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
  %input.4 : Tensor = aten::conv2d(%x1.1, %self.features1.0.weight, %self.features1.0.bias, %8, %8, %8, %11) 
  %input.6 : Tensor = aten::max_pool2d(%input.4, %7, %7, %6, %8, %10) 
  %x1.3 : Tensor = aten::relu(%input.6) 
  %input.7 : Tensor = aten::conv2d(%x2.1, %self.features2.0.weight, %self.features2.0.bias, %8, %8, %8, %11) 
  %input.8 : Tensor = aten::max_pool2d(%input.7, %7, %7, %6, %8, %10) 
  %x2.3 : Tensor = aten::relu(%input.8) 
  %26 : int = aten::dim(%x3.1) 
  %27 : bool = aten::eq(%26, %9) 
  %input.3 : Tensor = prim::If(%27) 
    block0():
      %ret.2 : Tensor = aten::addmm(%self.features3.0.bias, %x3.1, %5, %11, %11) 
      -> (%ret.2)
    block1():
      %output.2 : Tensor = aten::matmul(%x3.1, %5) 
      %output.4 : Tensor = aten::add_(%output.2, %self.features3.0.bias, %11) 
      -> (%output.4)
  %x3.3 : Tensor = aten::relu(%input.3) 
  %33 : int = aten::size(%x1.3, %12) 
  %34 : int[] = prim::ListConstruct(%33, %13)
  %x1.6 : Tensor = aten::view(%x1.3, %34) 
  %36 : int = aten::size(%x2.3, %12) 
  %37 : int[] = prim::ListConstruct(%36, %13)
  %x2.6 : Tensor = aten::view(%x2.3, %37) 
  %39 : int = aten::size(%x3.3, %12) 
  %40 : int[] = prim::ListConstruct(%39, %13)
  %x3.6 : Tensor = aten::view(%x3.3, %40)
  %42 : Tensor[] = prim::ListConstruct(%x1.6, %x2.6, %x3.6)
  %x.1 : Tensor = aten::cat(%42, %11) 
  %44 : int = aten::dim(%x.1)
  %45 : bool = aten::eq(%44, %9) 
  %x.3 : Tensor = prim::If(%45) 
    block0():
      %ret.1 : Tensor = aten::addmm(%self.classifier.bias, %x.1, %4, %11, %11) 
      -> (%ret.1)
    block1():
      %output.1 : Tensor = aten::matmul(%x.1, %4) 
      %output.3 : Tensor = aten::add_(%output.1, %self.classifier.bias, %11) 
      -> (%output.3)
  return (%x.3)

But I am not able to iterate over the graph to find out exactly which nodes are present in it or what inputs they take. Please suggest if there is any other way to perform this operation.

A for loop over gr.nodes() should work for a graph gr:

import torch
import torchvision
scripted = torch.jit.script(torchvision.models.resnet18().eval())
frozen_module = torch.jit.freeze(scripted)
gr = frozen_module.inlined_graph
for n in gr.nodes():
    print(n)

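Each node also exposes its operator name via kind() and its input/output Values via inputs() and outputs() (each Value has a debugName() matching the %-names in the printed graph). A minimal self-contained sketch, using a small scripted function instead of ResNet to keep the output readable:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.relu(x + y)

gr = f.graph
for n in gr.nodes():
    # kind() is the operator name, e.g. "aten::relu";
    # inputs()/outputs() yield torch._C.Value objects
    print(n.kind(),
          [i.debugName() for i in n.inputs()],
          "->",
          [o.debugName() for o in n.outputs()])
```

The same loop works on frozen_module.inlined_graph above; it will just print many more nodes.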
Note that it is easy to run into trouble

  • if you change the graph while it is being iterated over,
  • if the graph goes out of scope and is garbage-collected (that is why I assign it to a variable above).
    I tried to make disappearing nodes not crash the iterator, but I don’t exactly recall the current state of that implementation.
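
To sidestep both issues, one defensive pattern (a sketch, not a guarantee about the iterator internals) is to keep the graph in a variable and snapshot the node list before doing anything that might mutate the graph:

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x) + 1

gr = f.graph              # keep the graph alive in a variable
nodes = list(gr.nodes())  # snapshot the nodes before any mutation
for n in nodes:
    print(n.kind())
```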

Best regards

Thomas